Skip to content

Commit 8802126

Browse files
committed
fix: do not serialize to list on model dump
1 parent f30633e commit 8802126

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

lantern/functional_base.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,18 @@ def staticmethod(cls, fn):
3232
@classmethod
3333
def classmethod(cls, fn):
3434
return cls.setattr(fn.__name__, classmethod(fn))
35+
36+
37+
def test_replace_same_device():
38+
import torch
39+
40+
from .tensor import Tensor
41+
42+
class A(FunctionalBase):
43+
x: Tensor
44+
y: int
45+
46+
a = A(x=torch.tensor([1, 2, 3]).to("meta"), y=2)
47+
b = a.replace(y=2)
48+
49+
assert b.x.device == a.x.device

lantern/numpy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __get_pydantic_core_schema__(
5454
]
5555
),
5656
serialization=core_schema.plain_serializer_function_ser_schema(
57-
lambda instance: instance.tolist()
57+
lambda instance: instance
5858
),
5959
)
6060

lantern/tensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ def __get_pydantic_core_schema__(
5454
]
5555
),
5656
serialization=core_schema.plain_serializer_function_ser_schema(
57-
lambda instance: instance.tolist()
57+
lambda instance: instance
5858
),
5959
)
6060

0 commit comments

Comments
 (0)