|
1 | 1 | import copy |
2 | | -import re |
3 | 2 |
|
4 | 3 | import numpy as np |
5 | 4 | import pytest |
@@ -143,55 +142,68 @@ def rewrite(g, level="fast_run"): |
143 | 142 | return g |
144 | 143 |
|
145 | 144 |
|
146 | | -def test_local_useless_slice(): |
147 | | - # test a simple matrix |
148 | | - x = matrix("x") |
149 | | - mode_excluding = get_default_mode().excluding( |
150 | | - "local_useless_slice", "local_mul_canonizer" |
151 | | - ) |
152 | | - mode_including = ( |
153 | | - get_default_mode() |
154 | | - .including("local_useless_slice") |
155 | | - .excluding("local_mul_canonizer") |
156 | | - ) |
| 145 | +class TestME: |
| 146 | + def local_useless_slice_tester(self): |
| 147 | + # test a simple matrix |
| 148 | + x = matrix("x") |
| 149 | + mode_excluding = get_mode("NUMBA").excluding( |
| 150 | + "local_useless_slice", "local_mul_canonizer" |
| 151 | + ) |
| 152 | + mode_including = ( |
| 153 | + get_mode("NUMBA") |
| 154 | + .including("local_useless_slice") |
| 155 | + .excluding("local_mul_canonizer") |
| 156 | + ) |
157 | 157 |
|
158 | | - # test with and without the useless slice |
159 | | - o = 2 * x[0, :] |
160 | | - f_excluding = function([x], o, mode=mode_excluding) |
161 | | - f_including = function([x], o, mode=mode_including) |
162 | | - rng = np.random.default_rng(utt.fetch_seed()) |
163 | | - test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") |
164 | | - assert all(f_including(test_inp) == f_excluding(test_inp)) |
165 | | - # test to see if the slice is truly gone |
166 | | - apply_node = f_including.maker.fgraph.toposort()[0] |
167 | | - subtens = apply_node.op |
168 | | - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) |
169 | | - |
170 | | - # Now test that the stack trace is copied over properly, |
171 | | - # before before and after rewriting. |
172 | | - assert check_stack_trace(f_excluding, ops_to_check="all") |
173 | | - assert check_stack_trace(f_including, ops_to_check="all") |
174 | | - |
175 | | - # test a 4d tensor |
176 | | - z = tensor4("z") |
177 | | - o2 = z[1, :, :, 1] |
178 | | - o3 = z[0, :, :, :] |
179 | | - f_including_check = function([z], o2, mode=mode_including) |
180 | | - f_including_check_apply = function([z], o3, mode=mode_including) |
181 | | - |
182 | | - # The rewrite shouldn't apply here |
183 | | - apply_node = f_including_check.maker.fgraph.toposort()[0] |
184 | | - subtens = apply_node.op |
185 | | - assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 |
186 | | - # But it should here |
187 | | - apply_node = f_including_check_apply.maker.fgraph.toposort()[0] |
188 | | - subtens = apply_node.op |
189 | | - assert not any(isinstance(idx, slice) for idx in subtens.idx_list) |
190 | | - |
191 | | - # Finally, test that the stack trace is copied over properly, |
192 | | - # before before and after rewriting. |
193 | | - assert check_stack_trace(f_including_check, ops_to_check=Subtensor) |
194 | | - assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) |
| 158 | + # test with and without the useless slice |
| 159 | + o = 2 * x[0, :] |
| 160 | + f_excluding = function([x], o, mode=mode_excluding) |
| 161 | + f_including = function([x], o, mode=mode_including) |
| 162 | + rng = np.random.default_rng(utt.fetch_seed()) |
| 163 | + test_inp = rng.integers(-10, 10, (4, 4)).astype("float32") |
| 164 | + assert all(f_including(test_inp) == f_excluding(test_inp)) |
| 165 | + # test to see if the slice is truly gone |
| 166 | + apply_node = f_including.maker.fgraph.toposort()[0] |
| 167 | + subtens = apply_node.op |
| 168 | + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) |
| 169 | + |
| 170 | + # Now test that the stack trace is copied over properly, |
| 171 | + # before before and after rewriting. |
| 172 | + assert check_stack_trace(f_excluding, ops_to_check="all") |
| 173 | + assert check_stack_trace(f_including, ops_to_check="all") |
| 174 | + |
| 175 | + # test a 4d tensor |
| 176 | + z = tensor4("z") |
| 177 | + o2 = z[1, :, :, 1] |
| 178 | + o3 = z[0, :, :, :] |
| 179 | + f_including_check = function([z], o2, mode=mode_including) |
| 180 | + f_including_check_apply = function([z], o3, mode=mode_including) |
| 181 | + |
| 182 | + # The rewrite shouldn't apply here |
| 183 | + apply_node = f_including_check.maker.fgraph.toposort()[0] |
| 184 | + subtens = apply_node.op |
| 185 | + assert [isinstance(idx, slice) for idx in subtens.idx_list].count(True) == 2 |
| 186 | + # But it should here |
| 187 | + apply_node = f_including_check_apply.maker.fgraph.toposort()[0] |
| 188 | + subtens = apply_node.op |
| 189 | + assert not any(isinstance(idx, slice) for idx in subtens.idx_list) |
| 190 | + |
| 191 | + # Finally, test that the stack trace is copied over properly, |
| 192 | + # before before and after rewriting. |
| 193 | + assert check_stack_trace(f_including_check, ops_to_check=Subtensor) |
| 194 | + assert check_stack_trace(f_including_check_apply, ops_to_check=Subtensor) |
| 195 | + |
| 196 | + def test_t0(self): |
| 197 | + import pytensor |
| 198 | + |
| 199 | + x = pt.vector("x") |
| 200 | + pytensor.function([x], x + 1) |
| 201 | + |
| 202 | + def test_t1(self): |
| 203 | + self.local_useless_slice_tester() |
| 204 | + |
| 205 | + def test_t2(self): |
| 206 | + self.local_useless_slice_tester() |
195 | 207 |
|
196 | 208 |
|
197 | 209 | def test_local_useless_fill(): |
@@ -307,9 +319,7 @@ def test_inconsistent_shared(self, shape_unsafe): |
307 | 319 | # Error raised by Alloc Op |
308 | 320 | with pytest.raises( |
309 | 321 | ValueError, |
310 | | - match=re.escape( |
311 | | - "cannot assign slice of shape (3, 7) from input of shape (6, 7)" |
312 | | - ), |
| 322 | + match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)", |
313 | 323 | ): |
314 | 324 | f() |
315 | 325 |
|
@@ -1206,7 +1216,6 @@ def test_sum_bool_upcast(self): |
1206 | 1216 | f(5) |
1207 | 1217 |
|
1208 | 1218 |
|
1209 | | -@pytest.mark.xfail(reason="Numba does not support float16") |
1210 | 1219 | class TestLocalOptAllocF16(TestLocalOptAlloc): |
1211 | 1220 | dtype = "float16" |
1212 | 1221 |
|
|
0 commit comments