Skip to content

Commit 8dab9bc

Browse files
kaushikcfdinducer
authored andcommitted
PytatoPyOpenCLArrayContext.compile: support returning arrays
`compile` only supported compiling callables that returned array containers. Extends the logic to support compiling callables that simply return thawed arrays.
1 parent c014adb commit 8dab9bc

File tree

1 file changed

+79
-19
lines changed

1 file changed

+79
-19
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 79 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
from arraycontext import PytatoPyOpenCLArrayContext
3333
from arraycontext.container.traversal import rec_keyed_map_array_container
3434

35+
import abc
3536
import numpy as np
3637
from typing import Any, Callable, Tuple, Dict, Mapping
3738
from dataclasses import dataclass, field
@@ -81,7 +82,7 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
8182
@dataclass(frozen=True, eq=True)
8283
class LeafArrayDescriptor(AbstractInputDescriptor):
8384
dtype: np.dtype
84-
shape: Tuple[int, ...]
85+
shape: pt.array.ShapeType
8586

8687
# }}}
8788

@@ -140,9 +141,14 @@ def id_collector(keys, ary):
140141
return ary
141142

142143
rec_keyed_map_array_container(id_collector, arg)
144+
elif isinstance(arg, pt.Array):
145+
arg_id = (kw,)
146+
arg_id_to_arg[arg_id] = arg
147+
arg_id_to_descr[arg_id] = LeafArrayDescriptor(np.dtype(arg.dtype),
148+
arg.shape)
143149
else:
144150
raise ValueError("Argument to a compiled operator should be"
145-
" either a scalar or an array container. Got"
151+
" either a scalar, pt.Array or an array container. Got"
146152
f" '{arg}'.")
147153

148154
return pmap(arg_id_to_arg), pmap(arg_id_to_descr)
@@ -157,6 +163,9 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
157163
if np.isscalar(arg):
158164
name = arg_id_to_name[(kw,)]
159165
return pt.make_placeholder(name, (), np.dtype(type(arg)))
166+
elif isinstance(arg, pt.Array):
167+
name = arg_id_to_name[(kw,)]
168+
return pt.make_placeholder(name, arg.shape, arg.dtype)
160169
elif is_array_container_type(arg.__class__):
161170
def _rec_to_placeholder(keys, ary):
162171
name = arg_id_to_name[(kw,) + keys]
@@ -218,16 +227,28 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
218227

219228
return pytato_program
220229

221-
def _dag_to_compiled_func(self, dict_of_named_arrays,
230+
def _dag_to_compiled_func(self, ary_or_dict_of_named_arrays,
222231
input_id_to_name_in_program, output_id_to_name_in_program,
223232
output_template):
224-
pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
225-
226-
return CompiledFunction(
233+
if isinstance(ary_or_dict_of_named_arrays, pt.Array):
234+
output_id = "_pt_out"
235+
dict_of_named_arrays = pt.make_dict_of_named_arrays(
236+
{output_id: ary_or_dict_of_named_arrays})
237+
pytato_program = self._dag_to_transformed_loopy_prg(dict_of_named_arrays)
238+
return CompiledFunctionReturningArray(
227239
self.actx, pytato_program,
228240
input_id_to_name_in_program=input_id_to_name_in_program,
229-
output_id_to_name_in_program=output_id_to_name_in_program,
230-
output_template=output_template)
241+
output_name_in_program=output_id)
242+
elif isinstance(ary_or_dict_of_named_arrays, pt.DictOfNamedArrays):
243+
pytato_program = self._dag_to_transformed_loopy_prg(
244+
ary_or_dict_of_named_arrays)
245+
return CompiledFunctionReturningArrayContainer(
246+
self.actx, pytato_program,
247+
input_id_to_name_in_program=input_id_to_name_in_program,
248+
output_id_to_name_in_program=output_id_to_name_in_program,
249+
output_template=output_template)
250+
else:
251+
raise NotImplementedError(type(ary_or_dict_of_named_arrays))
231252

232253
def __call__(self, *args: Any, **kwargs: Any) -> Any:
233254
"""
@@ -261,13 +282,14 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
261282
**{kw: _get_f_placeholder_args(arg, kw, input_id_to_name_in_program)
262283
for kw, arg in kwargs.items()})
263284

264-
if not is_array_container_type(output_template.__class__):
285+
if (not (is_array_container_type(output_template.__class__)
286+
or isinstance(output_template, pt.Array))):
265287
# TODO: We could possibly just short-circuit this interface if the
266288
# returned type is a scalar. Not sure if it's worth it though.
267289
raise NotImplementedError(
268290
f"Function '{self.f.__name__}' to be compiled "
269-
"did not return an array container, but an instance of "
270-
f"'{output_template.__class__}' instead.")
291+
"did not return an array container or pt.Array,"
292+
f" but an instance of '{output_template.__class__}' instead.")
271293

272294
def _as_dict_of_named_arrays(keys, ary):
273295
name = "_pt_out_" + "_".join(str(key)
@@ -312,8 +334,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
312334
return input_kwargs_for_loopy
313335

314336

315-
@dataclass(frozen=True)
316-
class CompiledFunction:
337+
class CompiledFunction(abc.ABC):
317338
"""
318339
A callable which captures the :class:`pytato.target.BoundProgram` resulting
319340
from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of
@@ -328,6 +349,23 @@ class CompiledFunction:
328349
position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented
329350
with the leaf array's key if the argument is an array container.
330351
352+
353+
.. automethod:: __call__
354+
"""
355+
356+
@abc.abstractmethod
357+
def __call__(self, arg_id_to_arg) -> Any:
358+
"""
359+
:arg arg_id_to_arg: Mapping from input id to the passed argument. See
360+
:attr:`CompiledFunction.input_id_to_name_in_program` for input id's
361+
representation.
362+
"""
363+
pass
364+
365+
366+
@dataclass(frozen=True)
367+
class CompiledFunctionReturningArrayContainer(CompiledFunction):
368+
"""
331369
.. attribute:: output_id_to_name_in_program
332370
333371
A mapping from output id to the name of
@@ -341,19 +379,13 @@ class CompiledFunction:
341379
An instance of :class:`arraycontext.ArrayContainer` that is the return
342380
type of the callable.
343381
"""
344-
345382
actx: PytatoPyOpenCLArrayContext
346383
pytato_program: pt.target.BoundProgram
347384
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
348385
output_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
349386
output_template: ArrayContainer
350387

351388
def __call__(self, arg_id_to_arg) -> ArrayContainer:
352-
"""
353-
:arg arg_id_to_arg: Mapping from input id to the passed argument. See
354-
:attr:`CompiledFunction.input_id_to_name_in_program` for input id's
355-
representation.
356-
"""
357389
input_kwargs_for_loopy = _args_to_cl_buffers(
358390
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
359391

@@ -371,3 +403,31 @@ def to_output_template(keys, _):
371403

372404
return rec_keyed_map_array_container(to_output_template,
373405
self.output_template)
406+
407+
408+
@dataclass(frozen=True)
409+
class CompiledFunctionReturningArray(CompiledFunction):
410+
"""
411+
.. attribute:: output_name_in_program
412+
413+
Name of the output array in the program.
414+
"""
415+
actx: PytatoPyOpenCLArrayContext
416+
pytato_program: pt.target.BoundProgram
417+
input_id_to_name_in_program: Mapping[Tuple[Any, ...], str]
418+
output_name: str
419+
420+
def __call__(self, arg_id_to_arg) -> ArrayContainer:
421+
input_kwargs_for_loopy = _args_to_cl_buffers(
422+
self.actx, self.input_id_to_name_in_program, arg_id_to_arg)
423+
424+
evt, out_dict = self.pytato_program(queue=self.actx.queue,
425+
allocator=self.actx.allocator,
426+
**input_kwargs_for_loopy)
427+
428+
# FIXME Kernels (for now) allocate tons of memory in temporaries. If we
429+
# race too far ahead with enqueuing, there is a distinct risk of
430+
# running out of memory. This mitigates that risk a bit, for now.
431+
evt.wait()
432+
433+
return self.actx.thaw(out_dict[self.output_name])

0 commit comments

Comments
 (0)