3232from arraycontext import PytatoPyOpenCLArrayContext
3333from arraycontext .container .traversal import rec_keyed_map_array_container
3434
35+ import abc
3536import numpy as np
3637from typing import Any , Callable , Tuple , Dict , Mapping
3738from dataclasses import dataclass , field
@@ -81,7 +82,7 @@ class ScalarInputDescriptor(AbstractInputDescriptor):
8182@dataclass (frozen = True , eq = True )
8283class LeafArrayDescriptor (AbstractInputDescriptor ):
8384 dtype : np .dtype
84- shape : Tuple [ int , ...]
85+ shape : pt . array . ShapeType
8586
8687# }}}
8788
@@ -140,9 +141,14 @@ def id_collector(keys, ary):
140141 return ary
141142
142143 rec_keyed_map_array_container (id_collector , arg )
144+ elif isinstance (arg , pt .Array ):
145+ arg_id = (kw ,)
146+ arg_id_to_arg [arg_id ] = arg
147+ arg_id_to_descr [arg_id ] = LeafArrayDescriptor (np .dtype (arg .dtype ),
148+ arg .shape )
143149 else :
144150 raise ValueError ("Argument to a compiled operator should be"
145- " either a scalar or an array container. Got"
151+ " either a scalar, pt.Array or an array container. Got"
146152 f" '{ arg } '." )
147153
148154 return pmap (arg_id_to_arg ), pmap (arg_id_to_descr )
@@ -157,6 +163,9 @@ def _get_f_placeholder_args(arg, kw, arg_id_to_name):
157163 if np .isscalar (arg ):
158164 name = arg_id_to_name [(kw ,)]
159165 return pt .make_placeholder (name , (), np .dtype (type (arg )))
166+ elif isinstance (arg , pt .Array ):
167+ name = arg_id_to_name [(kw ,)]
168+ return pt .make_placeholder (name , arg .shape , arg .dtype )
160169 elif is_array_container_type (arg .__class__ ):
161170 def _rec_to_placeholder (keys , ary ):
162171 name = arg_id_to_name [(kw ,) + keys ]
@@ -218,16 +227,28 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
218227
219228 return pytato_program
220229
221- def _dag_to_compiled_func (self , dict_of_named_arrays ,
230+ def _dag_to_compiled_func (self , ary_or_dict_of_named_arrays ,
222231 input_id_to_name_in_program , output_id_to_name_in_program ,
223232 output_template ):
224- pytato_program = self ._dag_to_transformed_loopy_prg (dict_of_named_arrays )
225-
226- return CompiledFunction (
233+ if isinstance (ary_or_dict_of_named_arrays , pt .Array ):
234+ output_id = "_pt_out"
235+ dict_of_named_arrays = pt .make_dict_of_named_arrays (
236+ {output_id : ary_or_dict_of_named_arrays })
237+ pytato_program = self ._dag_to_transformed_loopy_prg (dict_of_named_arrays )
238+ return CompiledFunctionReturningArray (
227239 self .actx , pytato_program ,
228240 input_id_to_name_in_program = input_id_to_name_in_program ,
229- output_id_to_name_in_program = output_id_to_name_in_program ,
230- output_template = output_template )
241+ output_name_in_program = output_id )
242+ elif isinstance (ary_or_dict_of_named_arrays , pt .DictOfNamedArrays ):
243+ pytato_program = self ._dag_to_transformed_loopy_prg (
244+ ary_or_dict_of_named_arrays )
245+ return CompiledFunctionReturningArrayContainer (
246+ self .actx , pytato_program ,
247+ input_id_to_name_in_program = input_id_to_name_in_program ,
248+ output_id_to_name_in_program = output_id_to_name_in_program ,
249+ output_template = output_template )
250+ else :
251+ raise NotImplementedError (type (ary_or_dict_of_named_arrays ))
231252
232253 def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
233254 """
@@ -261,13 +282,14 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
261282 ** {kw : _get_f_placeholder_args (arg , kw , input_id_to_name_in_program )
262283 for kw , arg in kwargs .items ()})
263284
264- if not is_array_container_type (output_template .__class__ ):
285+ if (not (is_array_container_type (output_template .__class__ )
286+ or isinstance (output_template , pt .Array ))):
265287 # TODO: We could possibly just short-circuit this interface if the
266288 # returned type is a scalar. Not sure if it's worth it though.
267289 raise NotImplementedError (
268290 f"Function '{ self .f .__name__ } ' to be compiled "
269- "did not return an array container, but an instance of "
270- f"'{ output_template .__class__ } ' instead." )
291+ "did not return an array container or pt.Array, "
292+ f" but an instance of '{ output_template .__class__ } ' instead." )
271293
272294 def _as_dict_of_named_arrays (keys , ary ):
273295 name = "_pt_out_" + "_" .join (str (key )
@@ -312,8 +334,7 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
312334 return input_kwargs_for_loopy
313335
314336
315- @dataclass (frozen = True )
316- class CompiledFunction :
337+ class CompiledFunction (abc .ABC ):
317338 """
318339 A callable which captures the :class:`pytato.target.BoundProgram` resulting
319340 from calling :attr:`~LazilyCompilingFunctionCaller.f` with a given set of
@@ -328,6 +349,23 @@ class CompiledFunction:
328349 position of :attr:`~LazilyCompilingFunctionCaller.f`'s argument augmented
329350 with the leaf array's key if the argument is an array container.
330351
352+
353+ .. automethod:: __call__
354+ """
355+
356+ @abc .abstractmethod
357+ def __call__ (self , arg_id_to_arg ) -> Any :
358+ """
359+ :arg arg_id_to_arg: Mapping from input id to the passed argument. See
360+ :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
361+ representation.
362+ """
363+ pass
364+
365+
366+ @dataclass (frozen = True )
367+ class CompiledFunctionReturningArrayContainer (CompiledFunction ):
368+ """
331369 .. attribute:: output_id_to_name_in_program
332370
333371 A mapping from output id to the name of
@@ -341,19 +379,13 @@ class CompiledFunction:
341379 An instance of :class:`arraycontext.ArrayContainer` that is the return
342380 type of the callable.
343381 """
344-
345382 actx : PytatoPyOpenCLArrayContext
346383 pytato_program : pt .target .BoundProgram
347384 input_id_to_name_in_program : Mapping [Tuple [Any , ...], str ]
348385 output_id_to_name_in_program : Mapping [Tuple [Any , ...], str ]
349386 output_template : ArrayContainer
350387
351388 def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
352- """
353- :arg arg_id_to_arg: Mapping from input id to the passed argument. See
354- :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
355- representation.
356- """
357389 input_kwargs_for_loopy = _args_to_cl_buffers (
358390 self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
359391
@@ -371,3 +403,31 @@ def to_output_template(keys, _):
371403
372404 return rec_keyed_map_array_container (to_output_template ,
373405 self .output_template )
406+
407+
408+ @dataclass (frozen = True )
409+ class CompiledFunctionReturningArray (CompiledFunction ):
410+ """
411+ .. attribute:: output_name_in_program
412+
413+ Name of the output array in the program.
414+ """
415+ actx : PytatoPyOpenCLArrayContext
416+ pytato_program : pt .target .BoundProgram
417+ input_id_to_name_in_program : Mapping [Tuple [Any , ...], str ]
418+ output_name : str
419+
420+ def __call__ (self , arg_id_to_arg ) -> ArrayContainer :
421+ input_kwargs_for_loopy = _args_to_cl_buffers (
422+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
423+
424+ evt , out_dict = self .pytato_program (queue = self .actx .queue ,
425+ allocator = self .actx .allocator ,
426+ ** input_kwargs_for_loopy )
427+
428+ # FIXME Kernels (for now) allocate tons of memory in temporaries. If we
429+ # race too far ahead with enqueuing, there is a distinct risk of
430+ # running out of memory. This mitigates that risk a bit, for now.
431+ evt .wait ()
432+
433+ return self .actx .thaw (out_dict [self .output_name ])
0 commit comments