Skip to content

Commit a15bd58

Browse files
committed
[WIP] Test pt.push_index_to_materialized_nodes.
1 parent 5cd57bf commit a15bd58

File tree

1 file changed

+277
-0
lines changed

1 file changed

+277
-0
lines changed

test/test_transform.py

Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
from __future__ import annotations
2+
3+
4+
__copyright__ = "Copyright (C) 2026 Kaushik Kulkarni"
5+
6+
__license__ = """
7+
Permission is hereby granted, free of charge, to any person obtaining a copy
8+
of this software and associated documentation files (the "Software"), to deal
9+
in the Software without restriction, including without limitation the rights
10+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11+
copies of the Software, and to permit persons to whom the Software is
12+
furnished to do so, subject to the following conditions:
13+
14+
The above copyright notice and this permission notice shall be included in
15+
all copies or substantial portions of the Software.
16+
17+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
23+
THE SOFTWARE.
24+
"""
25+
26+
import numpy as np
27+
import pytest
28+
29+
import pytato as pt
30+
31+
32+
def test_indirection_pusher_0():
33+
x = pt.make_placeholder("x", 10)
34+
idx = pt.make_placeholder("idx", 1729, np.int32)
35+
y = x[idx]
36+
assert pt.push_index_to_materialized_nodes(y) == y
37+
38+
39+
def test_indirection_pusher_1():
40+
x = pt.make_placeholder("x", 10)
41+
idx = pt.make_placeholder("idx", 1729, np.int32)
42+
y = (2 * x)[idx]
43+
assert pt.push_index_to_materialized_nodes(y) == 2 * (x[idx])
44+
45+
46+
def test_indirection_pusher_2():
47+
x1 = pt.make_placeholder("x1", 10)
48+
x2 = pt.make_placeholder("x2", 10)
49+
idx = pt.make_placeholder("idx", 1729, np.int32)
50+
y = (x1 * x2)[idx]
51+
assert pt.push_index_to_materialized_nodes(y) == (x1[idx] + x2[idx])
52+
53+
54+
def test_indirection_pusher_3():
55+
x = pt.make_placeholder("x", 10)
56+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
57+
idx2 = pt.make_placeholder("idx2", 314, np.int32)
58+
assert pt.push_index_to_materialized_nodes(x[idx1][idx2]) == x[idx1[idx2]]
59+
60+
61+
def test_indirection_pusher_4():
62+
x = pt.make_placeholder("x", 10)
63+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
64+
idx2 = pt.make_placeholder("idx2", 314, np.int32)
65+
assert pt.push_index_to_materialized_nodes((2 * x[idx1])[idx2]) == 2 * x[idx1[idx2]]
66+
67+
68+
def test_indirection_pusher_5():
69+
x = pt.make_placeholder("x", (10, 10, 10, 10))
70+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
71+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
72+
idx3 = pt.make_placeholder("idx3", 314, np.int32)
73+
idx4 = pt.make_placeholder("idx4", 314, np.int32)
74+
idx5 = pt.make_placeholder("idx5", 314, np.int32)
75+
76+
assert (
77+
pt.push_index_to_materialized_nodes(x[:, idx1, idx2, :][idx3, idx4, idx5])
78+
== x[idx3, idx1[idx4], idx2[idx4], idx5]
79+
)
80+
81+
82+
def test_indirection_pusher_6():
83+
x = pt.make_placeholder("x", (10, 10, 10, 10))
84+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
85+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
86+
idx3 = pt.make_placeholder("idx3", 314, np.int32)
87+
idx4 = pt.make_placeholder("idx4", 314, np.int32)
88+
idx5 = pt.make_placeholder("idx5", 314, np.int32)
89+
90+
assert (
91+
pt.push_index_to_materialized_nodes(x[::2, idx1, idx2, ::3][idx3, idx4, idx5])
92+
== x[2 * idx3, idx1[idx4], idx2[idx4], 3 * idx5]
93+
)
94+
95+
96+
def test_indirection_pusher_7():
97+
x = pt.make_placeholder("x", (10, 10, 10))
98+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
99+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
100+
idx3 = pt.make_placeholder("idx1", 314, np.int32)
101+
idx4 = pt.make_placeholder("idx2", 314, np.int32)
102+
103+
assert (
104+
pt.push_index_to_materialized_nodes(x[idx1, :, idx2][idx3, idx4])
105+
== x[idx1[idx3], idx4, idx2[idx3]]
106+
)
107+
108+
109+
def test_indirection_pusher_8():
110+
x = pt.make_placeholder("x", (10, 10, 10))
111+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
112+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
113+
idx3 = pt.make_placeholder("idx1", 314, np.int32)
114+
idx4 = pt.make_placeholder("idx2", 314, np.int32)
115+
116+
assert (
117+
pt.push_index_to_materialized_nodes(x[idx1, ::2, idx2][idx3, idx4])
118+
== x[idx1[idx3], 2 * idx4, idx2[idx3]]
119+
)
120+
121+
122+
def test_indirection_pusher_9():
123+
x = pt.make_placeholder("x", (10, 10, 10, 10))
124+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
125+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
126+
idx3 = pt.make_placeholder("idx3", 314, np.int32)
127+
idx4 = pt.make_placeholder("idx4", 314, np.int32)
128+
129+
assert (
130+
pt.push_index_to_materialized_nodes(x[idx1, idx2, ::2, ::3][idx3, :, idx4])
131+
== x[idx1[idx3], idx2[idx3], ::2, 3 * idx4]
132+
)
133+
134+
135+
def test_indirection_pusher_10():
136+
x = pt.make_placeholder("x", (10, 10, 10, 10))
137+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
138+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
139+
idx3 = pt.make_placeholder("idx3", 314, np.int32)
140+
idx4 = pt.make_placeholder("idx4", 314, np.int32)
141+
# (_0, _1, _2) -> (idx1[_0], 2*_1, idx2[_0], _2)
142+
# (_0, _1) -> (idx3[_0], 3*_1, idx4[_0])
143+
# Net:
144+
# (_0, _1) -> (idx1[idx3[_0]], 6*_1, idx2[idx3[_0]], idx4[_0])
145+
146+
assert (
147+
pt.push_index_to_materialized_nodes(x[idx1, ::2, idx2][idx3, ::3, idx4])
148+
== x[idx1[idx3], ::6, idx2[idx3], idx4]
149+
)
150+
151+
152+
def test_indirection_pusher_11():
153+
x1 = pt.make_placeholder("x1", (10, 1, 10, 1))
154+
x2 = pt.make_placeholder("x2", (1, 10, 10, 10))
155+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
156+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
157+
y1 = (x1 + x2)[:, idx1, idx2, :]
158+
# (_0, _1, _2, _3) -> x1[_0, 0, _2, 0] + x2[0, _1, _2, _3]
159+
# (_0, _1, _2) -> (_0, idx1[_1], idx2[_1], _2])
160+
# Net ->
161+
# (_0, _1, _2) -> x1[_0, 0, idx2[_1], 0] + x2[0, idx1[_1], idx2[_1], _2]
162+
y2 = x1[:, 0, idx2, :] + x2[:, idx1, idx2, :]
163+
assert pt.push_index_to_materialized_nodes(y1) == y2
164+
165+
166+
def test_indirection_pusher_12():
167+
x1 = pt.make_placeholder("x1", (10, 1, 10, 1))
168+
x2 = pt.make_placeholder("x2", (1, 10, 10, 10))
169+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
170+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
171+
y1 = (x1 + x2)[idx1, :, idx2, :]
172+
# (_0, _1, _2, _3) -> x1[_0, 0, _2, 0] + x2[0, _1, _2, _3]
173+
# (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2)
174+
# Net->
175+
# (_0, _1, _2) -> x1[idx1[_0], 0, idx2[_0], 0] + x2[0, _1, idx2[_0], _2]
176+
177+
y2 = x1[idx1, :, idx2, :] + x2[0, :, idx2, :]
178+
assert pt.push_index_to_materialized_nodes(y1) == y2
179+
180+
181+
@pytest.mark.xfail("axis permutation not yet supported.")
182+
def test_indirection_pusher_13():
183+
x = pt.make_placeholder("x", (10, 10, 10, 10))
184+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
185+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
186+
y1 = pt.transpose(x, (0, 2, 3, 1))[idx1, :idx2, :]
187+
# (_0, _1, _2, _3) -> (_0, _2, _3, _1)
188+
# (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2)
189+
# Net->
190+
# (idx1[_0], idx2[_0], _2, _1)
191+
y2 = pt.transpose(x[idx1, idx2], (0, 1, 3, 2))
192+
assert pt.push_index_to_materialized_nodes(y1) == y2
193+
194+
195+
@pytest.mark.xfail("axis permutation not yet supported.")
196+
def test_indirection_pusher_14():
197+
x = pt.make_placeholder("x", (10, 10, 10, 10))
198+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
199+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
200+
y1 = pt.transpose(x, (0, 2, 3, 1))[idx1, :idx2, :]
201+
# (_0, _1, _2, _3) -> (_0, _2, _3, _1)
202+
# (_0, _1, _2) -> (idx1[_0], _1, idx2[_0], _2)
203+
# Net->
204+
# (idx1[_0], idx2[_0], _2, _1)
205+
y2 = pt.transpose(x[idx1, idx2], (0, 1, 3, 2))
206+
assert pt.push_index_to_materialized_nodes(y1) == y2
207+
208+
209+
def test_indirection_pusher_15():
210+
x = pt.make_placeholder("x", (10, 10))
211+
idx1 = pt.make_placeholder("idx1", 4, np.int32)
212+
idx2 = pt.make_placeholder("idx2", (10, 4), np.int32)
213+
idx3 = pt.make_placeholder("idx3", (1, 10, 10), np.int32)
214+
idx4 = pt.make_placeholder("idx4", (10, 10, 10), np.int32)
215+
assert (
216+
pt.push_index_to_materialized_nodes(x[idx1, idx2][idx3, idx4])
217+
== x[idx1[idx4], idx2[idx3]]
218+
)
219+
220+
221+
def test_indirection_pusher_16():
222+
x = pt.make_placeholder("x", (10, 10, 10))
223+
idx1 = pt.make_placeholder("idx1", (4, 1, 4), np.int32)
224+
idx2 = pt.make_placeholder("idx2", (10, 4), np.int32)
225+
idx3 = pt.make_placeholder("idx3", (10, 4), np.int32)
226+
idx4 = pt.make_placeholder("idx4", (10, 1), np.int32)
227+
idx5 = pt.make_placeholder("idx5", (10, 10), np.int32)
228+
assert (
229+
pt.push_index_to_materialized_nodes(x[idx1, idx2, idx3][idx4, 2:5, idx5])
230+
== x[idx1[idx4, :, idx5], idx2[2:5, idx5], idx3[idx5]]
231+
)
232+
233+
234+
def test_indirection_pusher_17():
235+
x = pt.make_placeholder("x", (10, 10, 10, 10))
236+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
237+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
238+
idx3 = pt.make_placeholder("idx1", 314, np.int32)
239+
idx4 = pt.make_placeholder("idx2", 314, np.int32)
240+
y1 = x[:, idx1, :, idx2][:, idx3, idx4]
241+
y2 = x[idx3, idx1.reshape(-1, 1), idx4, idx2.reshape(-1, 1)]
242+
assert pt.push_index_to_materialized_nodes(y1) == y2
243+
244+
245+
def test_indirection_pusher_18():
246+
x = pt.make_placeholder("x", (10, 10, 10))
247+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
248+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
249+
idx3 = pt.make_placeholder("idx1", 314, np.int32)
250+
idx4 = pt.make_placeholder("idx2", 314, np.int32)
251+
y1 = x[:, idx1, idx2, :][:, idx3, idx4]
252+
y2 = x[:, idx1[idx3], idx2[idx3], idx4]
253+
assert pt.push_index_to_materialized_nodes(y1) == y2
254+
255+
256+
def test_indirection_pusher_19():
257+
x = pt.make_placeholder("x", (10, 10, 10, 10, 10))
258+
idx1 = pt.make_placeholder("idx1", 1729, np.int32)
259+
idx2 = pt.make_placeholder("idx2", 1729, np.int32)
260+
idx3 = pt.make_placeholder("idx1", 314, np.int32)
261+
idx4 = pt.make_placeholder("idx2", 314, np.int32)
262+
y1 = x[:, idx1, :, idx2, :][:, :, idx3, idx4]
263+
y2 = pt.transpose(
264+
(1, 0, 2), x[:, idx1.reshape(-1, 1), idx3, idx2.reshape(-1, 1), idx4]
265+
)
266+
assert pt.push_index_to_materialized_nodes(y1) == y2
267+
268+
269+
def test_indirection_pusher_20():
270+
x = pt.make_placeholder("x", (10, 10, 10, 10))
271+
idx1 = pt.make_placeholder("idx1", (2718, 1729, 314), np.int32)
272+
idx2 = pt.make_placeholder("idx2", (1729, 314), np.int32)
273+
idx3 = pt.make_placeholder("idx3", 6, np.int32)
274+
idx4 = pt.make_placeholder("idx4", (10, 6), np.int32)
275+
y1 = x[:, idx1, :, idx2][:, :, :, idx3, idx4]
276+
y2 = x[idx3, pt.expand_dims(idx1, (3, 4)), idx4, pt.expand_dims(idx2, (2, 3))]
277+
assert pt.push_index_to_materialized_nodes(y1) == y2

0 commit comments

Comments
 (0)