Skip to content

Commit ebb06ff

Browse files
committed
Group subtensor specify_shape lift tests in class
1 parent 92a22af commit ebb06ff

File tree

1 file changed

+100
-100
lines changed

1 file changed

+100
-100
lines changed

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 100 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -310,106 +310,106 @@ def test_local_subtensor_of_alloc():
310310
assert xval.__getitem__(slices).shape == val.shape
311311

312312

313-
@pytest.mark.parametrize(
314-
"x, s, idx, x_val, s_val",
315-
[
316-
(
317-
vector(),
318-
(iscalar(),),
319-
(1,),
320-
np.array([1, 2], dtype=config.floatX),
321-
np.array([2], dtype=np.int64),
322-
),
323-
(
324-
matrix(),
325-
(iscalar(), iscalar()),
326-
(1,),
327-
np.array([[1, 2], [3, 4]], dtype=config.floatX),
328-
np.array([2, 2], dtype=np.int64),
329-
),
330-
(
331-
matrix(),
332-
(iscalar(), iscalar()),
333-
(0,),
334-
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
335-
np.array([2, 3], dtype=np.int64),
336-
),
337-
(
338-
matrix(),
339-
(iscalar(), iscalar()),
340-
(1, 1),
341-
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
342-
np.array([2, 3], dtype=np.int64),
343-
),
344-
(
345-
tensor3(),
346-
(iscalar(), iscalar(), iscalar()),
347-
(-1,),
348-
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
349-
np.array([2, 3, 5], dtype=np.int64),
350-
),
351-
(
352-
tensor3(),
353-
(iscalar(), iscalar(), iscalar()),
354-
(-1, 0),
355-
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
356-
np.array([2, 3, 5], dtype=np.int64),
357-
),
358-
],
359-
)
360-
def test_local_subtensor_SpecifyShape_lift(x, s, idx, x_val, s_val):
361-
y = specify_shape(x, s)[idx]
362-
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
363-
364-
rewrites = RewriteDatabaseQuery(include=[None])
365-
no_rewrites_mode = Mode(optimizer=rewrites)
366-
367-
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
368-
y_val = y_val_fn(*([x_val, *s_val]))
369-
370-
# This optimization should appear in the canonicalizations
371-
y_opt = rewrite_graph(y, clone=False)
372-
373-
if y.ndim == 0:
374-
# SpecifyShape should be removed altogether
375-
assert isinstance(y_opt.owner.op, Subtensor)
376-
assert y_opt.owner.inputs[0] is x
377-
else:
378-
assert isinstance(y_opt.owner.op, SpecifyShape)
379-
380-
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
381-
y_opt_val = y_opt_fn(*([x_val, *s_val]))
382-
383-
assert np.allclose(y_val, y_opt_val)
384-
385-
386-
@pytest.mark.parametrize(
387-
"x, s, idx",
388-
[
389-
(
390-
matrix(),
391-
(iscalar(), iscalar()),
392-
(slice(1, None),),
393-
),
394-
(
395-
matrix(),
396-
(iscalar(), iscalar()),
397-
(slicetype(),),
398-
),
399-
(
400-
matrix(),
401-
(iscalar(), iscalar()),
402-
(1, 0),
403-
),
404-
],
405-
)
406-
def test_local_subtensor_SpecifyShape_lift_fail(x, s, idx):
407-
y = specify_shape(x, s)[idx]
408-
409-
# This optimization should appear in the canonicalizations
410-
y_opt = rewrite_graph(y, clone=False)
411-
412-
assert not isinstance(y_opt.owner.op, SpecifyShape)
313+
class TestLocalSubtensorSpecifyShapeLift:
314+
@pytest.mark.parametrize(
315+
"x, s, idx, x_val, s_val",
316+
[
317+
(
318+
vector(),
319+
(iscalar(),),
320+
(1,),
321+
np.array([1, 2], dtype=config.floatX),
322+
np.array([2], dtype=np.int64),
323+
),
324+
(
325+
matrix(),
326+
(iscalar(), iscalar()),
327+
(1,),
328+
np.array([[1, 2], [3, 4]], dtype=config.floatX),
329+
np.array([2, 2], dtype=np.int64),
330+
),
331+
(
332+
matrix(),
333+
(iscalar(), iscalar()),
334+
(0,),
335+
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
336+
np.array([2, 3], dtype=np.int64),
337+
),
338+
(
339+
matrix(),
340+
(iscalar(), iscalar()),
341+
(1, 1),
342+
np.array([[1, 2, 3], [4, 5, 6]], dtype=config.floatX),
343+
np.array([2, 3], dtype=np.int64),
344+
),
345+
(
346+
tensor3(),
347+
(iscalar(), iscalar(), iscalar()),
348+
(-1,),
349+
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
350+
np.array([2, 3, 5], dtype=np.int64),
351+
),
352+
(
353+
tensor3(),
354+
(iscalar(), iscalar(), iscalar()),
355+
(-1, 0),
356+
np.arange(2 * 3 * 5, dtype=config.floatX).reshape((2, 3, 5)),
357+
np.array([2, 3, 5], dtype=np.int64),
358+
),
359+
],
360+
)
361+
def test_local_subtensor_SpecifyShape_lift(self, x, s, idx, x_val, s_val):
362+
y = specify_shape(x, s)[idx]
363+
assert isinstance(y.owner.inputs[0].owner.op, SpecifyShape)
364+
365+
rewrites = RewriteDatabaseQuery(include=[None])
366+
no_rewrites_mode = Mode(optimizer=rewrites)
367+
368+
y_val_fn = function([x, *s], y, on_unused_input="ignore", mode=no_rewrites_mode)
369+
y_val = y_val_fn(*([x_val, *s_val]))
370+
371+
# This optimization should appear in the canonicalizations
372+
y_opt = rewrite_graph(y, clone=False)
373+
374+
if y.ndim == 0:
375+
# SpecifyShape should be removed altogether
376+
assert isinstance(y_opt.owner.op, Subtensor)
377+
assert y_opt.owner.inputs[0] is x
378+
else:
379+
assert isinstance(y_opt.owner.op, SpecifyShape)
380+
381+
y_opt_fn = function([x, *s], y_opt, on_unused_input="ignore")
382+
y_opt_val = y_opt_fn(*([x_val, *s_val]))
383+
384+
assert np.allclose(y_val, y_opt_val)
385+
386+
@pytest.mark.parametrize(
387+
"x, s, idx",
388+
[
389+
(
390+
matrix(),
391+
(iscalar(), iscalar()),
392+
(slice(1, None),),
393+
),
394+
(
395+
matrix(),
396+
(iscalar(), iscalar()),
397+
(slicetype(),),
398+
),
399+
(
400+
matrix(),
401+
(iscalar(), iscalar()),
402+
(1, 0),
403+
),
404+
],
405+
)
406+
def test_local_subtensor_SpecifyShape_lift_fail(self, x, s, idx):
407+
y = specify_shape(x, s)[idx]
408+
409+
# This optimization should appear in the canonicalizations
410+
y_opt = rewrite_graph(y, clone=False)
411+
412+
assert not isinstance(y_opt.owner.op, SpecifyShape)
413413

414414

415415
class TestLocalSubtensorMakeVector:

0 commit comments

Comments
 (0)