Skip to content

Commit ed84a8a

Browse files
committed
.new_test
1 parent 5907874 commit ed84a8a

File tree

1 file changed

+62
-53
lines changed

1 file changed

+62
-53
lines changed

tests/tensor/rewriting/test_basic.py

Lines changed: 62 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import copy
2-
import re
32

43
import numpy as np
54
import pytest
@@ -143,55 +142,68 @@ def rewrite(g, level="fast_run"):
143142
return g
144143

145144

146-
def test_local_useless_slice():
147-
# test a simple matrix
148-
x = matrix("x")
149-
mode_excluding = get_default_mode().excluding(
150-
"local_useless_slice", "local_mul_canonizer"
151-
)
152-
mode_including = (
153-
get_default_mode()
154-
.including("local_useless_slice")
155-
.excluding("local_mul_canonizer")
156-
)
145+
class TestME:
146+
def local_useless_slice_tester(self):
147+
# test a simple matrix
148+
x = matrix("x")
149+
mode_excluding = get_mode("NUMBA").excluding(
150+
"local_useless_slice", "local_mul_canonizer"
151+
)
152+
mode_including = (
153+
get_mode("NUMBA")
154+
.including("local_useless_slice")
155+
.excluding("local_mul_canonizer")
156+
)
157157

158-
# test with and without the useless slice
159-
o = 2 * x[0, :]
160-
f_excluding = function([x], o, mode=mode_excluding)
161-
f_including = function([x], o, mode=mode_including)
162-
rng = np.random.default_rng(utt.fetch_seed())
163-
test_inp = rng.integers(-10, 10, (4, 4)).astype("float32")
164-
assert all(f_including(test_inp) == f_excluding(test_inp))
165-
# test to see if the slice is truly gone
166-
apply_node = f_including.maker.fgraph.toposort()[0]
167-
subtens = apply_node.op
168-
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
169-
170-
# Now test that the stack trace is copied over properly,
171-
# before before and after rewriting.
172-
assert check_stack_trace(f_excluding, ops_to_check="all")
173-
assert check_stack_trace(f_including, ops_to_check="all")
174-
175-
# test a 4d tensor
176-
z = tensor4("z")
177-
o2 = z[1, :, :, 1]
178-
o3 = z[0, :, :, :]
179-
f_including_check = function([z], o2, mode=mode_including)
180-
f_including_check_apply = function([z], o3, mode=mode_including)
181-
182-
# The rewrite shouldn't apply here
183-
apply_node = f_including_check.maker.fgraph.toposort()[0]
184-
subtens = apply_node.op
185-
assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2
186-
# But it should here
187-
apply_node = f_including_check_apply.maker.fgraph.toposort()[0]
188-
subtens = apply_node.op
189-
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
190-
191-
# Finally, test that the stack trace is copied over properly,
192-
# before before and after rewriting.
193-
assert check_stack_trace(f_including_check, ops_to_check=Subtensor)
194-
assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor)
158+
# test with and without the useless slice
159+
o = 2 * x[0, :]
160+
f_excluding = function([x], o, mode=mode_excluding)
161+
f_including = function([x], o, mode=mode_including)
162+
rng = np.random.default_rng(utt.fetch_seed())
163+
test_inp = rng.integers(-10, 10, (4, 4)).astype("float32")
164+
assert all(f_including(test_inp) == f_excluding(test_inp))
165+
# test to see if the slice is truly gone
166+
apply_node = f_including.maker.fgraph.toposort()[0]
167+
subtens = apply_node.op
168+
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
169+
170+
# Now test that the stack trace is copied over properly,
171+
# before before and after rewriting.
172+
assert check_stack_trace(f_excluding, ops_to_check="all")
173+
assert check_stack_trace(f_including, ops_to_check="all")
174+
175+
# test a 4d tensor
176+
z = tensor4("z")
177+
o2 = z[1, :, :, 1]
178+
o3 = z[0, :, :, :]
179+
f_including_check = function([z], o2, mode=mode_including)
180+
f_including_check_apply = function([z], o3, mode=mode_including)
181+
182+
# The rewrite shouldn't apply here
183+
apply_node = f_including_check.maker.fgraph.toposort()[0]
184+
subtens = apply_node.op
185+
assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2
186+
# But it should here
187+
apply_node = f_including_check_apply.maker.fgraph.toposort()[0]
188+
subtens = apply_node.op
189+
assert not any(isinstance(idx, slice) for idx in subtens.idx_list)
190+
191+
# Finally, test that the stack trace is copied over properly,
192+
# before before and after rewriting.
193+
assert check_stack_trace(f_including_check, ops_to_check=Subtensor)
194+
assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor)
195+
196+
def test_t0(self):
197+
import pytensor
198+
199+
x = pt.vector("x")
200+
pytensor.function([x], x + 1)
201+
202+
def test_t1(self):
203+
self.local_useless_slice_tester()
204+
205+
def test_t2(self):
206+
self.local_useless_slice_tester()
195207

196208

197209
def test_local_useless_fill():
@@ -307,9 +319,7 @@ def test_inconsistent_shared(self, shape_unsafe):
307319
# Error raised by Alloc Op
308320
with pytest.raises(
309321
ValueError,
310-
match=re.escape(
311-
"cannot assign slice of shape (3, 7) from input of shape (6, 7)"
312-
),
322+
match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
313323
):
314324
f()
315325

@@ -1206,7 +1216,6 @@ def test_sum_bool_upcast(self):
12061216
f(5)
12071217

12081218

1209-
@pytest.mark.xfail(reason="Numba does not support float16")
12101219
class TestLocalOptAllocF16(TestLocalOptAlloc):
12111220
dtype = "float16"
12121221

0 commit comments

Comments
 (0)