Skip to content

Commit 8e0a11e

Browse files
committed
ci: add cse ordering test
1 parent a05d55a commit 8e0a11e

File tree

2 files changed

+25
-1
lines changed

2 files changed

+25
-1
lines changed

devito/passes/clusters/cse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ def choose_element(queue, scheduled):
279279
first = sorted(tmps, key=lambda i: i.lhs.name).pop(0)
280280
queue.remove(first)
281281
else:
282-
first = queue.popleft()
282+
first = sorted(queue, key=lambda i: exprs.index(i)).pop(0)
283+
queue.remove(first)
283284
return first
284285

285286
processed = dag.topological_sort(choose_element)

tests/test_cse.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,26 @@ def test_advanced_algo(exprs, expected):
246246

247247
assert len(processed) == len(expected)
248248
assert all(str(i.rhs) == j for i, j in zip(processed, expected))
249+
250+
251+
def test_advanced_algo_order():
252+
"""
253+
Test that smartsort/advanced doesn't break equation order.
254+
"""
255+
grid = Grid((3, 3, 3))
256+
u = TimeFunction(name="u", grid=grid, space_order=2)
257+
v = TimeFunction(name="v", grid=grid, space_order=2)
258+
259+
eq0 = DummyEq(indexify(diffify(Eq(u.forward, u.dx).evaluate)))
260+
eq1 = DummyEq(indexify(diffify(Eq(v, u.dx).evaluate)))
261+
eq_b = DummyEq(indexify(diffify(Eq(v.forward, v + u.forward).evaluate)))
262+
263+
counter = generator()
264+
make = lambda _: CTemp(name='r%d' % counter(), dtype=np.float32).indexify()
265+
processed = _cse([eq0, eq1, eq_b], make, mode='advanced')
266+
267+
# Three input equation and 2 CTemps
268+
assert len(processed) == 5
269+
assert processed[0].lhs.name == 'r1'
270+
# eq_b has to be last
271+
assert processed[-1] == eq_b

0 commit comments

Comments
 (0)