Skip to content

Commit 1d1ba95

Browse files
authored
Merge pull request #2637 from devitocodes/cse-dequeue-fix
compiler: fix deque pop order
2 parents 8ec307c + 8e0a11e commit 1d1ba95

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

devito/passes/clusters/cse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -279,7 +279,8 @@ def choose_element(queue, scheduled):
279279
first = sorted(tmps, key=lambda i: i.lhs.name).pop(0)
280280
queue.remove(first)
281281
else:
282-
first = queue.pop()
282+
first = sorted(queue, key=lambda i: exprs.index(i)).pop(0)
283+
queue.remove(first)
283284
return first
284285

285286
processed = dag.topological_sort(choose_element)

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
pip>=9.0.1
2-
numpy>=2,<2.3.0
2+
numpy>=2,<2.3.1
33
sympy>=1.12.1,<1.15
44
psutil>=5.1.0,<8.0
55
py-cpuinfo<10

tests/test_cse.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,26 @@ def test_advanced_algo(exprs, expected):
246246

247247
assert len(processed) == len(expected)
248248
assert all(str(i.rhs) == j for i, j in zip(processed, expected))
249+
250+
251+
def test_advanced_algo_order():
252+
"""
253+
Test that smartsort/advanced doesn't break equation order.
254+
"""
255+
grid = Grid((3, 3, 3))
256+
u = TimeFunction(name="u", grid=grid, space_order=2)
257+
v = TimeFunction(name="v", grid=grid, space_order=2)
258+
259+
eq0 = DummyEq(indexify(diffify(Eq(u.forward, u.dx).evaluate)))
260+
eq1 = DummyEq(indexify(diffify(Eq(v, u.dx).evaluate)))
261+
eq_b = DummyEq(indexify(diffify(Eq(v.forward, v + u.forward).evaluate)))
262+
263+
counter = generator()
264+
make = lambda _: CTemp(name='r%d' % counter(), dtype=np.float32).indexify()
265+
processed = _cse([eq0, eq1, eq_b], make, mode='advanced')
266+
267+
# Three input equation and 2 CTemps
268+
assert len(processed) == 5
269+
assert processed[0].lhs.name == 'r1'
270+
# eq_b has to be last
271+
assert processed[-1] == eq_b

0 commit comments

Comments
 (0)