Skip to content

Commit 0b445af

Browse files
committed
compiler: Emit arguments used to invoke kernels and add test for specialization with MPI
1 parent f82d22a commit 0b445af

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

devito/operator/operator.py

Lines changed: 30 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -989,27 +989,22 @@ def apply(self, **kwargs):
989989
# Get items expected to be specialized
990990
specialize = as_tuple(kwargs.pop('specialize', []))
991991

992-
# In the case of specialization, arguments must be processed before
993-
# the operator can be compiled
994992
if specialize:
995993
# FIXME: Cannot cope with things like sizes/strides yet since it only
996994
# looks at the parameters
997995

998996
# Build the arguments list for specialization
999-
with self._profiler.timer_on('specialized-arguments-preprocess'):
997+
with self._profiler.timer_on('specialization'):
1000998
args = self.arguments(**kwargs)
1001-
with switch_log_level(comm=args.comm):
1002-
self._emit_args_profiling('specialized-arguments-preprocess')
999+
# Uses parameters here since Specializer needs {symbol: sympy value}
1000+
specialized_values = {p: sympify(args[p.name])
1001+
for p in self.parameters if p.name in specialize}
10031002

1004-
# Uses parameters here since Specializer needs {symbol: sympy value} mapper
1005-
specialized_values = {p: sympify(args[p.name])
1006-
for p in self.parameters if p.name in specialize}
1003+
op = Specializer(specialized_values).visit(self)
10071004

1008-
op = Specializer(specialized_values).visit(self)
1005+
with switch_log_level(comm=args.comm):
1006+
self._emit_args_profiling('specialization')
10091007

1010-
# TODO: Does this cause problems for profilers?
1011-
# FIXME: Need some way to inspect this Operator for testing
1012-
# FIXME: Perhaps this should use some separate method
10131008
unspecialized_kwargs = {k: v for k, v in kwargs.items()
10141009
if k not in specialize}
10151010

@@ -1025,9 +1020,7 @@ def apply(self, **kwargs):
10251020
with switch_log_level(comm=args.comm):
10261021
self._emit_args_profiling('arguments-preprocess')
10271022

1028-
args_string = ", ".join([f"{p.name}={args[p.name]}"
1029-
for p in self.parameters if p.is_Symbol])
1030-
debug(f"Invoking `{self.name}` with scalar arguments: {args_string}")
1023+
self._emit_arguments(args)
10311024

10321025
# Invoke kernel function with args
10331026
arg_values = [args[p.name] for p in self.parameters]
@@ -1064,6 +1057,28 @@ def _emit_args_profiling(self, tag=''):
10641057
tagstr = ' '.join(tag.split('-'))
10651058
debug(f"Operator `{self.name}` {tagstr}: {elapsed:.2f} s")
10661059

1060+
def _emit_arguments(self, args):
    """
    Log, at debug level, the scalar arguments used to invoke the kernel.

    Only the parameters flagged ``is_Symbol`` are reported, each one as a
    ``name=value`` pair looked up in ``args``.

    Without MPI (``args.comm`` is ``MPI.COMM_NULL``) a single entry is
    logged. With MPI, every rank contributes one entry prefixed with its
    rank; the entries are exchanged with ``allgather`` and rank 0 logs
    them all.

    Parameters
    ----------
    args : mapping-like
        The processed runtime arguments. It must expose ``comm`` and be
        indexable by parameter name.
    """
    comm = args.comm
    with_mpi = comm is not MPI.COMM_NULL

    # One `name=value` pair per scalar (Symbol) parameter
    pairs = [f"{p.name}={args[p.name]}"
             for p in self.parameters if p.is_Symbol]
    joined = ", ".join(pairs)

    # The rank prefix is only meaningful when a communicator is available
    prefix = f"[rank{comm.Get_rank()}] " if with_mpi else ""
    entry = f"* {prefix}{joined}"

    with switch_log_level(comm=comm):
        debug(f"Scalar arguments used to invoke `{self.name}`")

        if not with_mpi:
            debug(entry)
            return

        # `allgather` is a collective: every rank must take part, even
        # though only rank 0 goes on to emit the gathered entries
        entries = comm.allgather(entry)
        if comm.Get_rank() == 0:
            for e in entries:
                debug(e)
1081+
10671082
def _emit_build_profiling(self):
10681083
if not is_log_enabled_for('PERF'):
10691084
return

tests/test_specialization.py

Lines changed: 6 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -228,12 +228,15 @@ def test_basic(self, caplog, override):
228228

229229
# Ensure that the specialized operator was run
230230
assert all(s not in caplog.text for s in specialize)
231-
assert "specialized arguments preprocess" in caplog.text
231+
assert "specialization" in caplog.text
232232

233233
check = np.array(f.data[:])
234234
f.data[:] = 0
235235
op.apply(**kwargs)
236236

237-
assert np.all(check == f.data)
237+
assert np.all(check == f.data[:])
238238

239-
# Need to test specialization with MPI (both at)
239+
@pytest.mark.parallel(mode=[2, 4])
@pytest.mark.parametrize('override', [False, True])
def test_basic_mpi(self, caplog, mode, override):
    """
    Run the checks of `test_basic` under the `parallel` marker with
    ``mode=[2, 4]``, for both values of `override`.
    """
    # `mode` is supplied through the `parallel` marker and is not
    # forwarded; the checks themselves are those of `test_basic`
    self.test_basic(caplog=caplog, override=override)

0 commit comments

Comments
 (0)