Skip to content

Commit aff0288

Browse files
committed
refactor!: pydantic 2 pattern for numpy and pytorch typing
BREAKING CHANGE
1 parent bd7e5fe commit aff0288

File tree

6 files changed

+374
-203
lines changed

6 files changed

+374
-203
lines changed

lantern/__init__.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,23 @@
1-
from .numpy_seed import numpy_seed
2-
from .star import star
3-
from .functional_base import FunctionalBase
4-
from .tensor import Tensor
5-
from .numpy import Numpy
61
from .epochs import Epochs
2+
from .functional_base import FunctionalBase
3+
from .lambda_module import Lambda
74
from .metric import Metric
85
from .metric_table import MetricTable
9-
from .lambda_module import Lambda
106
from .module_device import module_device
11-
from .module_train import module_train, module_eval
7+
from .module_train import module_eval, module_train
8+
from .numpy import Numpy
9+
from .numpy_seed import numpy_seed
10+
from .progress_bar import ProgressBar
1211
from .requires_grad import requires_grad, requires_nograd
1312
from .set_learning_rate import set_learning_rate
1413
from .set_seeds import set_seeds
14+
from .star import star
15+
from .tensor import Tensor
1516
from .worker_init_fn import worker_init_fn
16-
from .progress_bar import ProgressBar
1717

1818
try:
1919
from .early_stopping import EarlyStopping
2020
except ImportError:
2121
pass
22-
from .git_info import git_info
23-
24-
from pkg_resources import get_distribution, DistributionNotFound
2522

26-
try:
27-
__version__ = get_distribution("pytorch-lantern").version
28-
except DistributionNotFound:
29-
__version__ = "dev"
23+
from .git_info import git_info

lantern/functional_base.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,12 @@
1-
from pydantic import BaseModel, Extra
1+
from pydantic import BaseModel
22

33

44
class FunctionalBase(BaseModel):
5-
class Config:
6-
allow_mutation = False
7-
extra = Extra.forbid
8-
95
def map(self, fn, *args, **kwargs):
106
return fn(self, *args, **kwargs)
117

128
def replace(self, **kwargs):
13-
new_dict = self.dict()
9+
new_dict = self.model_dump()
1410
new_dict.update(**kwargs)
1511
return type(self)(**new_dict)
1612

lantern/numpy.py

Lines changed: 67 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,72 @@
11
from __future__ import annotations
22

3+
from typing import Any, List
4+
35
import numpy as np
46
import torch
7+
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
8+
from pydantic.json_schema import JsonSchemaValue
9+
from pydantic_core import core_schema
10+
from typing_extensions import Annotated
11+
12+
13+
def validate_from_list(values: List) -> np.ndarray:
14+
return np.array(values)
15+
516

17+
def validate_from_torch(tensor: torch.Tensor) -> np.ndarray:
18+
return tensor.numpy()
19+
20+
21+
class Numpy:
22+
@classmethod
23+
def __get_pydantic_core_schema__(
24+
cls,
25+
_source_type: Any,
26+
_handler: GetCoreSchemaHandler,
27+
) -> core_schema.CoreSchema:
28+
from_list_schema = core_schema.chain_schema(
29+
[
30+
core_schema.list_schema(),
31+
core_schema.no_info_plain_validator_function(validate_from_list),
32+
]
33+
)
34+
35+
from_torch_schema = core_schema.chain_schema(
36+
[
37+
core_schema.is_instance_schema(torch.Tensor),
38+
core_schema.no_info_plain_validator_function(validate_from_torch),
39+
]
40+
)
41+
42+
return core_schema.json_or_python_schema(
43+
json_schema=from_list_schema,
44+
python_schema=core_schema.chain_schema(
45+
[
46+
core_schema.union_schema(
47+
[
48+
core_schema.is_instance_schema(np.ndarray),
49+
from_list_schema,
50+
from_torch_schema,
51+
]
52+
),
53+
core_schema.no_info_plain_validator_function(cls.validate),
54+
]
55+
),
56+
serialization=core_schema.plain_serializer_function_ser_schema(
57+
lambda instance: instance.tolist()
58+
),
59+
)
660

7-
class Numpy(np.ndarray):
861
@classmethod
9-
def __get_validators__(cls):
10-
yield cls.validate
62+
def __get_pydantic_json_schema__(
63+
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
64+
) -> JsonSchemaValue:
65+
return handler(core_schema.list_schema())
1166

1267
@classmethod
13-
def validate(cls, data, config=None, field=None) -> np.ndarray:
14-
if isinstance(data, cls):
15-
return data.view(np.ndarray)
16-
elif isinstance(data, np.ndarray):
17-
return data
18-
elif isinstance(data, torch.Tensor):
19-
return data.numpy()
20-
else:
21-
return np.array(data)
68+
def validate(cls, data, config=None, field=None):
69+
return data
2270

2371
@classmethod
2472
def ndim(cls, ndim) -> Numpy:
@@ -229,9 +277,7 @@ def test_base_model():
229277
from pytest import raises
230278

231279
class Test(BaseModel):
232-
images: Numpy.dims("NCHW")
233-
234-
Test(images=np.ones((10, 3, 32, 32)))
280+
images: Annotated[np.ndarray, Numpy.dims("NCHW")]
235281

236282
with raises(ValueError):
237283
Test(images=np.ones((10, 3, 32)))
@@ -249,7 +295,7 @@ def test_conversion():
249295
from pydantic import BaseModel
250296

251297
class Test(BaseModel):
252-
numbers: Numpy.dims("N")
298+
numbers: Annotated[np.ndarray, Numpy.dims("N")]
253299

254300
Test(numbers=[1.1, 2.1, 3.1])
255301
Test(numbers=torch.tensor([1.1, 2.1, 3.1]))
@@ -270,15 +316,15 @@ def test_dtype():
270316
from pytest import raises
271317

272318
class Test(BaseModel):
273-
numbers: Numpy.uint8()
319+
numbers: Annotated[np.ndarray, Numpy.uint8()]
274320

275321
Test(numbers=[1, 2, 3])
276322

277323
with raises(ValueError):
278324
Test(numbers=[1.5, 2.2, 3.2])
279325

280326
class TestBool(BaseModel):
281-
flags: Numpy.bool()
327+
flags: Annotated[np.ndarray, Numpy.bool()]
282328

283329
TestBool(flags=[True, False, True])
284330

@@ -291,7 +337,7 @@ def test_from_torch():
291337
from pydantic import BaseModel
292338

293339
class Test(BaseModel):
294-
numbers: Numpy
340+
numbers: Annotated[np.ndarray, Numpy]
295341

296342
numbers = torch.tensor([1, 2, 3])
297343
numpy_numbers = Test(numbers=numbers).numbers
@@ -305,7 +351,7 @@ def test_between():
305351
from pytest import raises
306352

307353
class Test(BaseModel):
308-
numbers: Numpy.between(1, 3.5)
354+
numbers: Annotated[np.ndarray, Numpy.between(1, 3.5)]
309355

310356
Test(numbers=[1.5, 2.2, 3.2])
311357

@@ -318,7 +364,7 @@ def test_gt():
318364
from pytest import raises
319365

320366
class Test(BaseModel):
321-
numbers: Numpy.gt(1)
367+
numbers: Annotated[np.ndarray, Numpy.gt(1)]
322368

323369
Test(numbers=[1.5, 2.2, 3.2])
324370

lantern/tensor.py

Lines changed: 69 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,72 @@
11
from __future__ import annotations
22

3+
from typing import Any, List
4+
35
import numpy as np
46
import torch
5-
import torch._C
7+
from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler
8+
from pydantic.json_schema import JsonSchemaValue
9+
from pydantic_core import core_schema
10+
from typing_extensions import Annotated
11+
612

13+
def validate_from_list(values: List) -> torch.Tensor:
14+
return torch.tensor(values)
715

8-
class MetaTensor(torch._C._TensorMeta):
9-
def __getitem__(self, validate: Tensor) -> Tensor:
10-
return validate
1116

17+
def validate_from_numpy(array: np.ndarray) -> torch.Tensor:
18+
return torch.from_numpy(array)
1219

13-
class Tensor(torch.Tensor, metaclass=MetaTensor):
20+
21+
class Tensor:
1422
@classmethod
15-
def Validate(cls) -> Tensor:
16-
return cls
23+
def __get_pydantic_core_schema__(
24+
cls,
25+
_source_type: Any,
26+
_handler: GetCoreSchemaHandler,
27+
) -> core_schema.CoreSchema:
28+
from_list_schema = core_schema.chain_schema(
29+
[
30+
core_schema.list_schema(),
31+
core_schema.no_info_plain_validator_function(validate_from_list),
32+
]
33+
)
34+
35+
from_numpy_schema = core_schema.chain_schema(
36+
[
37+
core_schema.is_instance_schema(np.ndarray),
38+
core_schema.no_info_plain_validator_function(validate_from_numpy),
39+
]
40+
)
41+
42+
return core_schema.json_or_python_schema(
43+
json_schema=from_list_schema,
44+
python_schema=core_schema.chain_schema(
45+
[
46+
core_schema.union_schema(
47+
[
48+
core_schema.is_instance_schema(torch.Tensor),
49+
from_list_schema,
50+
from_numpy_schema,
51+
]
52+
),
53+
core_schema.no_info_plain_validator_function(cls.validate),
54+
]
55+
),
56+
serialization=core_schema.plain_serializer_function_ser_schema(
57+
lambda instance: instance.tolist()
58+
),
59+
)
1760

1861
@classmethod
19-
def __get_validators__(cls):
20-
yield cls.validate
62+
def __get_pydantic_json_schema__(
63+
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
64+
) -> JsonSchemaValue:
65+
return handler(core_schema.list_schema())
2166

2267
@classmethod
23-
def validate(cls, data, config=None, field=None) -> torch.Tensor:
24-
if isinstance(data, cls):
25-
return torch.tensor(data)
26-
elif isinstance(data, torch.Tensor):
27-
return data
28-
elif isinstance(data, np.ndarray):
29-
return torch.from_numpy(data)
30-
else:
31-
return torch.as_tensor(data)
68+
def validate(cls, data, config=None, field=None):
69+
return data
3270

3371
@classmethod
3472
def ndim(cls, ndim) -> Tensor:
@@ -252,14 +290,11 @@ def bool(cls) -> Tensor:
252290
return cls.dtype(torch.bool)
253291

254292

255-
Validate = Tensor
256-
257-
258293
def test_base_model():
259294
from pydantic import BaseModel
260295

261296
class Test(BaseModel):
262-
tensor: Tensor.dims("NCHW").float()
297+
tensor: Annotated[torch.Tensor, Tensor.dims("NCHW").float()]
263298

264299
Test(tensor=torch.ones(10, 3, 32, 32))
265300

@@ -276,8 +311,8 @@ def test_conversion():
276311
from pydantic import BaseModel
277312

278313
class Test(BaseModel):
279-
numbers: Tensor.dims("N")
280-
numbers2: Tensor.dims("N")
314+
numbers: Annotated[torch.Tensor, Tensor.dims("N")]
315+
numbers2: Annotated[torch.Tensor, Tensor.dims("N")]
281316

282317
Test(
283318
numbers=[1.1, 2.1, 3.1],
@@ -300,15 +335,15 @@ def test_dtype():
300335
from pytest import raises
301336

302337
class Test(BaseModel):
303-
numbers: Tensor.uint8()
338+
numbers: Annotated[torch.Tensor, Tensor.uint8()]
304339

305340
Test(numbers=[1, 2, 3])
306341

307342
with raises(ValueError):
308343
Test(numbers=[1.5, 2.2, 3.2])
309344

310345
class TestBool(BaseModel):
311-
flags: Tensor.bool()
346+
flags: Annotated[torch.Tensor, Tensor.bool()]
312347

313348
TestBool(flags=[True, False, True])
314349

@@ -320,7 +355,7 @@ def test_device():
320355
from pydantic import BaseModel
321356

322357
class Test(BaseModel):
323-
numbers: Tensor.float().cpu()
358+
numbers: Annotated[torch.Tensor, Tensor.float().cpu()]
324359

325360
Test(numbers=[1, 2, 3])
326361

@@ -329,7 +364,7 @@ def test_from_numpy():
329364
from pydantic import BaseModel
330365

331366
class Test(BaseModel):
332-
numbers: Tensor
367+
numbers: Annotated[torch.Tensor, Tensor]
333368

334369
numbers = np.array([1, 2, 3])
335370
torch_numbers = Test(numbers=numbers).numbers
@@ -343,7 +378,7 @@ def test_ge():
343378
from pytest import raises
344379

345380
class Test(BaseModel):
346-
numbers: Tensor.ge(0)
381+
numbers: Annotated[torch.Tensor, Tensor.ge(0)]
347382

348383
Test(numbers=[1.5, 2.2, 3.2])
349384

@@ -356,22 +391,22 @@ def test_ne():
356391
from pytest import raises
357392

358393
class Test(BaseModel):
359-
numbers: Tensor.ne(1)
394+
numbers: Annotated[torch.Tensor, Tensor.ne(1)]
360395

361396
Test(numbers=[1.5, 2.2, 3.2])
362397

363398
with raises(ValueError):
364399
Test(numbers=[1, 2.2, 3.2])
365400

366401

367-
def test_alternative_syntax():
402+
def test_shorthand_syntax():
368403
from pydantic import BaseModel
369404
from pytest import raises
370405

371406
class Test(BaseModel):
372-
numbers: Tensor[Validate.ne(1)]
407+
numbers: Tensor.dims("N").float()
373408

374-
Test(numbers=[1.5, 2.2, 3.2])
409+
Test(numbers=[1.5, 2.2, 3.2]).numbers
375410

376411
with raises(ValueError):
377-
Test(numbers=[1, 2.2, 3.2])
412+
Test(numbers=[[1, 2.2, 3.2], [1, 2, 3]])

0 commit comments

Comments
 (0)