Skip to content

Commit 35991c2

Browse files
Merge pull request #124 from inducer/refactor-pytato-compile-exec
Refactor PytatoActx.compile for usability by distributed
2 parents 33884ae + 36fd05c commit 35991c2

File tree

1 file changed

+85
-70
lines changed

1 file changed

+85
-70
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 85 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,49 @@ class LazilyCompilingFunctionCaller:
187187
program_cache: Dict["PMap[Tuple[Any, ...], AbstractInputDescriptor]",
188188
"CompiledFunction"] = field(default_factory=lambda: {})
189189

190+
def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
191+
from pytato.target.loopy import BoundPyOpenCLProgram
192+
193+
import loopy as lp
194+
195+
with ProcessLogger(logger, "transform_dag"):
196+
pt_dict_of_named_arrays = self.actx.transform_dag(
197+
pt.make_dict_of_named_arrays(dict_of_named_arrays))
198+
199+
with ProcessLogger(logger, "generate_loopy"):
200+
pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
201+
options=lp.Options(
202+
return_dict=True,
203+
no_numpy=True),
204+
cl_device=self.actx.queue.device)
205+
assert isinstance(pytato_program, BoundPyOpenCLProgram)
206+
207+
with ProcessLogger(logger, "transform_loopy_program"):
208+
209+
pytato_program = (pytato_program
210+
.with_transformed_program(
211+
lambda x: x.with_kernel(
212+
x.default_entrypoint
213+
.tagged(FromArrayContextCompile()))))
214+
215+
pytato_program = (pytato_program
216+
.with_transformed_program(self
217+
.actx
218+
.transform_loopy_program))
219+
220+
return pytato_program
221+
222+
def _dag_to_compiled_func(self, dict_of_named_arrays,
223+
input_id_to_name_in_program, output_id_to_name_in_program,
224+
output_template):
225+
pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
226+
227+
return CompiledFunction(
228+
self.actx, pytato_program,
229+
input_id_to_name_in_program=input_id_to_name_in_program,
230+
output_id_to_name_in_program=output_id_to_name_in_program,
231+
output_template=output_template)
232+
190233
def __call__(self, *args: Any, **kwargs: Any) -> Any:
191234
"""
192235
Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s
@@ -197,8 +240,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
197240
:attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy sense.
198241
The intermediary pytato DAG for *args* is memoized in *self*.
199242
"""
200-
from pytato.target.loopy import BoundPyOpenCLProgram
201-
202243
arg_id_to_arg, arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr(
203244
args, kwargs)
204245

@@ -210,74 +251,70 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
210251
return compiled_f(arg_id_to_arg)
211252

212253
dict_of_named_arrays = {}
213-
# output_naming_map: result id to name of the named array in the
214-
# generated pytato DAG.
215-
output_naming_map = {}
216-
# input_naming_map: argument id to placeholder name in the generated
217-
# pytato DAG.
218-
input_naming_map = {
254+
output_id_to_name_in_program = {}
255+
input_id_to_name_in_program = {
219256
arg_id: f"_actx_in_{_ary_container_key_stringifier(arg_id)}"
220257
for arg_id in arg_id_to_arg}
221258

222-
outputs = self.f(*[_get_f_placeholder_args(arg, iarg, input_naming_map)
223-
for iarg, arg in enumerate(args)],
224-
**{kw: _get_f_placeholder_args(arg, kw, input_naming_map)
225-
for kw, arg in kwargs.items()})
259+
output_template = self.f(
260+
*[_get_f_placeholder_args(arg, iarg, input_id_to_name_in_program)
261+
for iarg, arg in enumerate(args)],
262+
**{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program)
263+
for kw, arg in kwargs.items()})
226264

227-
if not is_array_container_type(outputs.__class__):
265+
if not is_array_container_type(output_template.__class__):
228266
# TODO: We could possibly just short-circuit this interface if the
229267
# returned type is a scalar. Not sure if it's worth it though.
230268
raise NotImplementedError(
231269
f"Function '{self.f.__name__}' to be compiled "
232270
"did not return an array container, but an instance of "
233-
f"'{outputs.__class__}' instead.")
271+
f"'{output_template.__class__}' instead.")
234272

235273
def _as_dict_of_named_arrays(keys, ary):
236274
name = "_pt_out_" + "_".join(str(key)
237275
for key in keys)
238-
output_naming_map[keys] = name
276+
output_id_to_name_in_program[keys] = name
239277
dict_of_named_arrays[name] = ary
240278
return ary
241279

242280
rec_keyed_map_array_container(_as_dict_of_named_arrays,
243-
outputs)
281+
output_template)
244282

245-
import loopy as lp
283+
from pytato import DictOfNamedArrays
284+
compiled_func = self._dag_to_compiled_func(
285+
DictOfNamedArrays(dict_of_named_arrays),
286+
input_id_to_name_in_program=input_id_to_name_in_program,
287+
output_id_to_name_in_program=output_id_to_name_in_program,
288+
output_template=output_template)
246289

247-
with ProcessLogger(logger, "transform_dag"):
248-
pt_dict_of_named_arrays = self.actx.transform_dag(
249-
pt.make_dict_of_named_arrays(dict_of_named_arrays))
290+
self.program_cache[arg_id_to_descr] = compiled_func
291+
return compiled_func(arg_id_to_arg)
250292

251-
with ProcessLogger(logger, "generate_loopy"):
252-
pytato_program = pt.generate_loopy(pt_dict_of_named_arrays,
253-
options=lp.Options(
254-
return_dict=True,
255-
no_numpy=True),
256-
cl_device=self.actx.queue.device)
257-
assert isinstance(pytato_program, BoundPyOpenCLProgram)
258293

259-
with ProcessLogger(logger, "transform_loopy_program"):
294+
def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
295+
input_kwargs_for_loopy = {}
260296

261-
pytato_program = (pytato_program
262-
.with_transformed_program(
263-
lambda x: x.with_kernel(
264-
x.default_entrypoint
265-
.tagged(FromArrayContextCompile()))))
266-
267-
pytato_program = (pytato_program
268-
.with_transformed_program(self
269-
.actx
270-
.transform_loopy_program))
297+
for arg_id, arg in arg_id_to_arg.items():
298+
if np.isscalar(arg):
299+
arg = cla.to_device(actx.queue, np.array(arg))
300+
elif isinstance(arg, pt.array.DataWrapper):
301+
# got a DataWrapper => simply get its data
302+
arg = arg.data
303+
elif isinstance(arg, cla.Array):
304+
# got a frozen array => do nothing
305+
pass
306+
elif isinstance(arg, pt.Array):
307+
# got an array expression => evaluate it
308+
arg = actx.freeze(arg).with_queue(actx.queue)
309+
else:
310+
raise NotImplementedError(type(arg))
271311

272-
self.program_cache[arg_id_to_descr] = CompiledFunction(
273-
self.actx, pytato_program,
274-
input_naming_map, output_naming_map,
275-
output_template=outputs)
312+
input_kwargs_for_loopy[input_id_to_name_in_program[arg_id]] = arg
276313

277-
return self.program_cache[arg_id_to_descr](arg_id_to_arg)
314+
return input_kwargs_for_loopy
278315

279316

280-
@dataclass
317+
@dataclass(frozen=True)
281318
class CompiledFunction:
282319
"""
283320
A callable which captures the :class:`pytato.target.BoundProgram` resulting
@@ -319,40 +356,18 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
319356
:attr:`CompiledFunction.input_id_to_name_in_program` for input id's
320357
representation.
321358
"""
322-
from arraycontext.container.traversal import rec_keyed_map_array_container
323-
324-
input_kwargs_to_loopy = {}
325-
326-
# {{{ preprocess args to get arguments (CL buffers) to be fed to the
327-
# loopy program
328-
329-
for arg_id, arg in arg_id_to_arg.items():
330-
if np.isscalar(arg):
331-
arg = cla.to_device(self.actx.queue, np.array(arg))
332-
elif isinstance(arg, pt.array.DataWrapper):
333-
# got a DataWrapper => simply get its data
334-
arg = arg.data
335-
elif isinstance(arg, cla.Array):
336-
# got a frozen array => do nothing
337-
pass
338-
elif isinstance(arg, pt.Array):
339-
# got an array expression => evaluate it
340-
arg = self.actx.freeze(arg).with_queue(self.actx.queue)
341-
else:
342-
raise NotImplementedError(type(arg))
343-
344-
input_kwargs_to_loopy[self.input_id_to_name_in_program[arg_id]] = arg
359+
input_kwargs_for_loopy = _args_to_cl_buffers(
360+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
345361

346362
evt, out_dict = self.pytato_program(queue=self.actx.queue,
347363
allocator=self.actx.allocator,
348-
**input_kwargs_to_loopy)
364+
**input_kwargs_for_loopy)
365+
349366
# FIXME Kernels (for now) allocate tons of memory in temporaries. If we
350367
# race too far ahead with enqueuing, there is a distinct risk of
351368
# running out of memory. This mitigates that risk a bit, for now.
352369
evt.wait()
353370

354-
# }}}
355-
356371
def to_output_template(keys, _):
357372
return self.actx.thaw(out_dict[self.output_id_to_name_in_program[keys]])
358373

0 commit comments

Comments
 (0)