Skip to content

Commit bf491cc

Browse files
FelixAbrahamssonFelixAbrahamsson
authored andcommitted
improve: faster dtype validate in the normal case
1 parent f117ecb commit bf491cc

File tree

2 files changed

+22
-9
lines changed

2 files changed

+22
-9
lines changed

lantern/numpy.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from __future__ import annotations
2+
23
import numpy as np
34
import torch
45

@@ -150,10 +151,15 @@ class InheritNumpy(cls):
150151
@classmethod
151152
def validate(cls, data):
152153
data = super().validate(data)
153-
new_data = data.astype(dtype)
154-
if not np.allclose(data, new_data, equal_nan=True):
155-
raise ValueError(f"Was unable to cast from {data.dtype} to {dtype}")
156-
return new_data
154+
if data.dtype == dtype:
155+
return data
156+
else:
157+
new_data = data.astype(dtype)
158+
if not np.allclose(data, new_data, equal_nan=True):
159+
raise ValueError(
160+
f"Was unable to cast from {data.dtype} to {dtype}"
161+
)
162+
return new_data
157163

158164
return InheritNumpy
159165

@@ -239,8 +245,8 @@ def test_validate():
239245

240246

241247
def test_conversion():
242-
from pydantic import BaseModel
243248
import torch
249+
from pydantic import BaseModel
244250

245251
class Test(BaseModel):
246252
numbers: Numpy.dims("N")

lantern/tensor.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -177,10 +177,17 @@ class InheritTensor(cls):
177177
@classmethod
178178
def validate(cls, data):
179179
data = super().validate(data)
180-
new_data = data.type(dtype)
181-
if not torch.allclose(data.float(), new_data.float(), equal_nan=True):
182-
raise ValueError(f"Was unable to cast from {data.dtype} to {dtype}")
183-
return new_data
180+
if data.dtype == dtype:
181+
return data
182+
else:
183+
new_data = data.type(dtype)
184+
if not torch.allclose(
185+
data.float(), new_data.float(), equal_nan=True
186+
):
187+
raise ValueError(
188+
f"Was unable to cast from {data.dtype} to {dtype}"
189+
)
190+
return new_data
184191

185192
return InheritTensor
186193

0 commit comments

Comments
 (0)