Skip to content

Commit f117ecb

Browse files
committed
feature: alternative tensor typing syntax
1 parent ee65295 commit f117ecb

File tree

1 file changed

+29
-2
lines changed

1 file changed

+29
-2
lines changed

lantern/tensor.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,20 @@
11
from __future__ import annotations
2+
23
import numpy as np
34
import torch
5+
import torch._C
6+
7+
8+
class MetaTensor(torch._C._TensorMeta):
9+
def __getitem__(self, validate: Tensor) -> Tensor:
10+
return validate
411

512

6-
class Tensor(torch.Tensor):
13+
class Tensor(torch.Tensor, metaclass=MetaTensor):
14+
@classmethod
15+
def Validate(cls) -> Tensor:
16+
return cls
17+
718
@classmethod
819
def __get_validators__(cls):
920
yield cls.validate
@@ -234,6 +245,9 @@ def bool(cls) -> Tensor:
234245
return cls.dtype(torch.bool)
235246

236247

248+
Validate = Tensor
249+
250+
237251
def test_base_model():
238252
from pydantic import BaseModel
239253

@@ -251,8 +265,8 @@ def test_validate():
251265

252266

253267
def test_conversion():
254-
from pydantic import BaseModel
255268
import numpy as np
269+
from pydantic import BaseModel
256270

257271
class Test(BaseModel):
258272
numbers: Tensor.dims("N")
@@ -341,3 +355,16 @@ class Test(BaseModel):
341355

342356
with raises(ValueError):
343357
Test(numbers=[1, 2.2, 3.2])
358+
359+
360+
def test_alternative_syntax():
361+
from pydantic import BaseModel
362+
from pytest import raises
363+
364+
class Test(BaseModel):
365+
numbers: Tensor[Validate.ne(1)]
366+
367+
Test(numbers=[1.5, 2.2, 3.2])
368+
369+
with raises(ValueError):
370+
Test(numbers=[1, 2.2, 3.2])

0 commit comments

Comments
 (0)