|
2 | 2 | from sympy import Mod |
3 | 3 | import pytest |
4 | 4 |
|
5 | | -from devito import Grid, Eq, Function, TimeFunction, Operator, sin |
| 5 | +from devito import Grid, Eq, Function, TimeFunction, Operator, Min, sin |
6 | 6 | from devito.ir.equations import DummyEq |
7 | 7 | from devito.ir.iet import (Block, Expression, Callable, FindNodes, FindSections, |
8 | 8 | FindSymbols, IsPerfectIteration, Transformer, |
9 | | - Conditional, printAST, Iteration, MapNodes, Call) |
10 | | -from devito.types import SpaceDimension, Array |
| 9 | + Conditional, printAST, Iteration, MapNodes, Call, |
| 10 | + FindApplications) |
| 11 | +from devito.types import SpaceDimension, Array, Symbol |
11 | 12 |
|
12 | 13 |
|
13 | 14 | @pytest.fixture(scope="module") |
@@ -395,3 +396,11 @@ def test_map_nodes(block1): |
395 | 396 | processed = Transformer({iters[0]: Call(callback.name)}).visit(block1) |
396 | 397 |
|
397 | 398 | assert str(processed) == 'solver();' |
| 399 | + |
| 400 | + |
| 401 | +def test_find_apps_nested_calls(exprs, iters): |
| 402 | + s = Symbol(name='s') |
| 403 | + call = Call('foo', Call('bar', [Min(s, 1)])) |
| 404 | + block = iters[0](iters[1](exprs + [call])) |
| 405 | + |
| 406 | + assert len(FindApplications().visit(block)) == 1 |
0 commit comments