
Commit b5d6f92

Move test_local_subtensor_of_dot to test_subtensor_lift

1 parent: 4d539fa

2 files changed: +44, -46 lines

tests/tensor/rewriting/test_math.py

Lines changed: 0 additions & 46 deletions
@@ -1203,52 +1203,6 @@ def test_local_log_add_exp():
     # TODO: test that the rewrite works in the presence of broadcasting.
 
 
-def test_local_subtensor_of_dot():
-    m1 = matrix()
-    m2 = matrix()
-    d1 = np.arange(6).reshape((3, 2)).astype(config.floatX)
-    d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10
-    mode = get_default_mode().including("local_subtensor_of_dot")
-
-    def test_equality(a, b):
-        return a.shape == b.shape and np.allclose(a, b)
-
-    # [cst]
-    f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert test_equality(f(d1, d2), np.dot(d1, d2)[1])
-    # DimShuffle happen in FAST_COMPILE
-    assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle)
-
-    # slice
-    f = function([m1, m2], pytensor.tensor.dot(m1, m2)[1:2], mode=mode)
-    topo = f.maker.fgraph.toposort()
-    assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2])
-    assert isinstance(topo[-1].op, Dot22)
-
-    m1 = tensor3()
-    m2 = tensor3()
-    idx = iscalar()
-    d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX)
-    d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100
-
-    f = function(
-        [m1, m2, idx], pytensor.tensor.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode
-    )
-    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:])
-    # if we return the gradients. We need to use same mode as before.
-    assert check_stack_trace(f, ops_to_check="last")
-
-    f = function(
-        [m1, m2, idx], pytensor.tensor.dot(m1, m2)[1:4, :, idx:, idx], mode=mode
-    )
-    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
-
-    # Now test that the stack trace is copied over properly,
-    # if we return the gradients. We need to use same mode as before.
-    assert check_stack_trace(f, ops_to_check="last")
-
-
 def test_local_elemwise_sub_zeros():
     scal = scalar()
     vect = vector()

tests/tensor/rewriting/test_subtensor_lift.py

Lines changed: 44 additions & 0 deletions
@@ -37,6 +37,8 @@
     vector,
 )
 from pytensor.tensor.basic import MakeVector, concatenate, expand_dims, make_vector
+from pytensor.tensor.blas import Dot22, Gemv
+from pytensor.tensor.blas_c import CGemv
 from pytensor.tensor.elemwise import DimShuffle, Elemwise
 from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.rewriting.subtensor_lift import (
@@ -178,6 +180,48 @@ def test_local_subtensor_of_elemwise_multiple_clients(self):
         assert local_subtensor_of_elemwise.transform(fgraph, out2.owner) is not None
 
 
+def test_local_subtensor_of_dot():
+    m1 = matrix()
+    m2 = matrix()
+    d1 = np.arange(6).reshape((3, 2)).astype(config.floatX)
+    d2 = np.arange(8).reshape((2, 4)).astype(config.floatX) + 10
+    mode = get_default_mode().including("local_subtensor_of_dot")
+
+    def test_equality(a, b):
+        return a.shape == b.shape and np.allclose(a, b)
+
+    # [cst]
+    f = function([m1, m2], pt.dot(m1, m2)[1], mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert test_equality(f(d1, d2), np.dot(d1, d2)[1])
+    # DimShuffle happen in FAST_COMPILE
+    assert isinstance(topo[-1].op, CGemv | Gemv | DimShuffle)
+
+    # slice
+    f = function([m1, m2], pt.dot(m1, m2)[1:2], mode=mode)
+    topo = f.maker.fgraph.toposort()
+    assert test_equality(f(d1, d2), np.dot(d1, d2)[1:2])
+    assert isinstance(topo[-1].op, Dot22)
+
+    m1 = tensor3()
+    m2 = tensor3()
+    idx = iscalar()
+    d1 = np.arange(30).reshape(2, 5, 3).astype(config.floatX)
+    d2 = np.arange(72).reshape(4, 3, 6).astype(config.floatX) + 100
+
+    f = function([m1, m2, idx], pt.dot(m1, m2)[idx, 1:4, :, idx:], mode=mode)
+    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1, 1:4, :, 1:])
+    # if we return the gradients. We need to use same mode as before.
+    assert check_stack_trace(f, ops_to_check="last")
+
+    f = function([m1, m2, idx], pt.dot(m1, m2)[1:4, :, idx:, idx], mode=mode)
+    assert test_equality(f(d1, d2, 1), np.dot(d1, d2)[1:4, :, 1:, 1])
+
+    # Now test that the stack trace is copied over properly,
+    # if we return the gradients. We need to use same mode as before.
+    assert check_stack_trace(f, ops_to_check="last")
+
+
 @pytest.mark.parametrize(
     "original_fn, expected_fn",
     [
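Aside (not part of the commit): the rewrite this test exercises, local_subtensor_of_dot, lifts the indexing operation onto the operands of the dot product, presumably rewriting dot(m1, m2)[i] as dot(m1[i], m2) so that only the required rows are computed. The plain-NumPy sketch below checks the underlying algebraic identity on the same shapes the test uses; it illustrates the idea and is not code from this commit.

import numpy as np

d1 = np.arange(6).reshape((3, 2)).astype("float64")
d2 = np.arange(8).reshape((2, 4)).astype("float64") + 10

# Row 1 of the product equals the product of row 1 of d1 with d2,
# i.e. indexing commutes with the left operand of a matmul.
assert np.allclose(np.dot(d1, d2)[1], np.dot(d1[1], d2))

# The same holds for a slice of rows; for that case the test expects the
# compiled graph to end in a Dot22 acting on the (smaller) sliced operand.
assert np.allclose(np.dot(d1, d2)[1:2], np.dot(d1[1:2], d2))

After the move, the relocated test can be run on its own with pytest tests/tensor/rewriting/test_subtensor_lift.py::test_local_subtensor_of_dot.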
