File tree Expand file tree Collapse file tree 2 files changed +22
-9
lines changed Expand file tree Collapse file tree 2 files changed +22
-9
lines changed Original file line number Diff line number Diff line change 1
1
from __future__ import annotations
2
+
2
3
import numpy as np
3
4
import torch
4
5
@@ -150,10 +151,15 @@ class InheritNumpy(cls):
150
151
@classmethod
151
152
def validate (cls , data ):
152
153
data = super ().validate (data )
153
- new_data = data .astype (dtype )
154
- if not np .allclose (data , new_data , equal_nan = True ):
155
- raise ValueError (f"Was unable to cast from { data .dtype } to { dtype } " )
156
- return new_data
154
+ if data .dtype == dtype :
155
+ return data
156
+ else :
157
+ new_data = data .astype (dtype )
158
+ if not np .allclose (data , new_data , equal_nan = True ):
159
+ raise ValueError (
160
+ f"Was unable to cast from { data .dtype } to { dtype } "
161
+ )
162
+ return new_data
157
163
158
164
return InheritNumpy
159
165
@@ -239,8 +245,8 @@ def test_validate():
239
245
240
246
241
247
def test_conversion ():
242
- from pydantic import BaseModel
243
248
import torch
249
+ from pydantic import BaseModel
244
250
245
251
class Test (BaseModel ):
246
252
numbers : Numpy .dims ("N" )
Original file line number Diff line number Diff line change @@ -177,10 +177,17 @@ class InheritTensor(cls):
177
177
@classmethod
178
178
def validate (cls , data ):
179
179
data = super ().validate (data )
180
- new_data = data .type (dtype )
181
- if not torch .allclose (data .float (), new_data .float (), equal_nan = True ):
182
- raise ValueError (f"Was unable to cast from { data .dtype } to { dtype } " )
183
- return new_data
180
+ if data .dtype == dtype :
181
+ return data
182
+ else :
183
+ new_data = data .type (dtype )
184
+ if not torch .allclose (
185
+ data .float (), new_data .float (), equal_nan = True
186
+ ):
187
+ raise ValueError (
188
+ f"Was unable to cast from { data .dtype } to { dtype } "
189
+ )
190
+ return new_data
184
191
185
192
return InheritTensor
186
193
You can’t perform that action at this time.
0 commit comments