@@ -53,8 +53,6 @@ def compare_pytorch_and_py(
53
53
assert_fn: func, opt
54
54
Assert function used to check for equality between python and pytorch. If not
55
55
provided uses np.testing.assert_allclose
56
- must_be_device_array: Bool
57
- Checks if torch.device.type is cuda
58
56
59
57
60
58
"""
@@ -66,20 +64,19 @@ def compare_pytorch_and_py(
66
64
pytensor_torch_fn = function (fn_inputs , fgraph .outputs , mode = pytorch_mode )
67
65
pytorch_res = pytensor_torch_fn (* test_inputs )
68
66
69
- if must_be_device_array :
70
- if isinstance (pytorch_res , list ):
71
- assert all (isinstance (res , torch .Tensor ) for res in pytorch_res )
72
- else :
73
- assert pytorch_res .device .type == "cuda"
67
+ if isinstance (pytorch_res , list ):
68
+ assert all (isinstance (res , np .ndarray ) for res in pytorch_res )
69
+ else :
70
+ assert isinstance (pytorch_res , np .ndarray )
74
71
75
72
pytensor_py_fn = function (fn_inputs , fgraph .outputs , mode = py_mode )
76
73
py_res = pytensor_py_fn (* test_inputs )
77
74
78
75
if len (fgraph .outputs ) > 1 :
79
76
for pytorch_res_i , py_res_i in zip (pytorch_res , py_res , strict = True ):
80
- assert_fn (pytorch_res_i . detach (). cpu (). numpy () , py_res_i )
77
+ assert_fn (pytorch_res_i , py_res_i )
81
78
else :
82
- assert_fn (pytorch_res [0 ]. detach (). cpu (). numpy () , py_res [0 ])
79
+ assert_fn (pytorch_res [0 ], py_res [0 ])
83
80
84
81
return pytensor_torch_fn , pytorch_res
85
82
@@ -162,23 +159,23 @@ def test_shared(device):
162
159
pytensor_torch_fn = function ([], a , mode = "PYTORCH" )
163
160
pytorch_res = pytensor_torch_fn ()
164
161
165
- assert isinstance (pytorch_res , torch . Tensor )
162
+ assert isinstance (pytorch_res , np . ndarray )
166
163
assert isinstance (a .get_value (), np .ndarray )
167
- np .testing .assert_allclose (pytorch_res . cpu () , a .get_value ())
164
+ np .testing .assert_allclose (pytorch_res , a .get_value ())
168
165
169
166
pytensor_torch_fn = function ([], a * 2 , mode = "PYTORCH" )
170
167
pytorch_res = pytensor_torch_fn ()
171
168
172
- assert isinstance (pytorch_res , torch . Tensor )
169
+ assert isinstance (pytorch_res , np . ndarray )
173
170
assert isinstance (a .get_value (), np .ndarray )
174
- np .testing .assert_allclose (pytorch_res . cpu () , a .get_value () * 2 )
171
+ np .testing .assert_allclose (pytorch_res , a .get_value () * 2 )
175
172
176
173
new_a_value = np .array ([3 , 4 , 5 ], dtype = config .floatX )
177
174
a .set_value (new_a_value )
178
175
179
176
pytorch_res = pytensor_torch_fn ()
180
- assert isinstance (pytorch_res , torch . Tensor )
181
- np .testing .assert_allclose (pytorch_res . cpu () , new_a_value * 2 )
177
+ assert isinstance (pytorch_res , np . ndarray )
178
+ np .testing .assert_allclose (pytorch_res , new_a_value * 2 )
182
179
183
180
184
181
@pytest .mark .parametrize ("device" , ["cpu" , "cuda" ])
@@ -225,7 +222,7 @@ def test_alloc_and_empty():
225
222
fn = function ([dim1 ], out , mode = pytorch_mode )
226
223
res = fn (7 )
227
224
assert res .shape == (5 , 7 , 3 )
228
- assert res .dtype == torch .float32
225
+ assert res .dtype == np .float32
229
226
230
227
v = vector ("v" , shape = (3 ,), dtype = "float64" )
231
228
out = alloc (v , dim0 , dim1 , 3 )
0 commit comments