@@ -187,6 +187,49 @@ class LazilyCompilingFunctionCaller:
187187 program_cache : Dict ["PMap[Tuple[Any, ...], AbstractInputDescriptor]" ,
188188 "CompiledFunction" ] = field (default_factory = lambda : {})
189189
190+ def _dag_to_transformed_loopy_prg (self , dict_of_named_arrays ):
191+ from pytato .target .loopy import BoundPyOpenCLProgram
192+
193+ import loopy as lp
194+
195+ with ProcessLogger (logger , "transform_dag" ):
196+ pt_dict_of_named_arrays = self .actx .transform_dag (
197+ pt .make_dict_of_named_arrays (dict_of_named_arrays ))
198+
199+ with ProcessLogger (logger , "generate_loopy" ):
200+ pytato_program = pt .generate_loopy (pt_dict_of_named_arrays ,
201+ options = lp .Options (
202+ return_dict = True ,
203+ no_numpy = True ),
204+ cl_device = self .actx .queue .device )
205+ assert isinstance (pytato_program , BoundPyOpenCLProgram )
206+
207+ with ProcessLogger (logger , "transform_loopy_program" ):
208+
209+ pytato_program = (pytato_program
210+ .with_transformed_program (
211+ lambda x : x .with_kernel (
212+ x .default_entrypoint
213+ .tagged (FromArrayContextCompile ()))))
214+
215+ pytato_program = (pytato_program
216+ .with_transformed_program (self
217+ .actx
218+ .transform_loopy_program ))
219+
220+ return pytato_program
221+
222+ def _dag_to_compiled_func (self , dict_of_named_arrays ,
223+ input_id_to_name_in_program , output_id_to_name_in_program ,
224+ output_template ):
225+ pytato_program = self ._dag_to_transformed_loopy_prg (dict_of_named_arrays )
226+
227+ return CompiledFunction (
228+ self .actx , pytato_program ,
229+ input_id_to_name_in_program = input_id_to_name_in_program ,
230+ output_id_to_name_in_program = output_id_to_name_in_program ,
231+ output_template = output_template )
232+
190233 def __call__ (self , * args : Any , ** kwargs : Any ) -> Any :
191234 """
192235 Returns the result of :attr:`~LazilyCompilingFunctionCaller.f`'s
@@ -197,8 +240,6 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
197240 :attr:`~LazilyCompilingFunctionCaller.f` with *args* in a lazy-sense.
198241 The intermediary pytato DAG for *args* is memoized in *self*.
199242 """
200- from pytato .target .loopy import BoundPyOpenCLProgram
201-
202243 arg_id_to_arg , arg_id_to_descr = _get_arg_id_to_arg_and_arg_id_to_descr (
203244 args , kwargs )
204245
@@ -210,74 +251,70 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any:
210251 return compiled_f (arg_id_to_arg )
211252
212253 dict_of_named_arrays = {}
213- # output_naming_map: result id to name of the named array in the
214- # generated pytato DAG.
215- output_naming_map = {}
216- # input_naming_map: argument id to placeholder name in the generated
217- # pytato DAG.
218- input_naming_map = {
254+ output_id_to_name_in_program = {}
255+ input_id_to_name_in_program = {
219256 arg_id : f"_actx_in_{ _ary_container_key_stringifier (arg_id )} "
220257 for arg_id in arg_id_to_arg }
221258
222- outputs = self .f (* [_get_f_placeholder_args (arg , iarg , input_naming_map )
223- for iarg , arg in enumerate (args )],
224- ** {kw : _get_f_placeholder_args (arg , kw , input_naming_map )
225- for kw , arg in kwargs .items ()})
259+ output_template = self .f (
260+ * [_get_f_placeholder_args (arg , iarg , input_id_to_name_in_program )
261+ for iarg , arg in enumerate (args )],
262+ ** {kw : _get_f_placeholder_args (arg , kw , input_id_to_name_in_program )
263+ for kw , arg in kwargs .items ()})
226264
227- if not is_array_container_type (outputs .__class__ ):
265+ if not is_array_container_type (output_template .__class__ ):
228266 # TODO: We could possibly just short-circuit this interface if the
229267 # returned type is a scalar. Not sure if it's worth it though.
230268 raise NotImplementedError (
231269 f"Function '{ self .f .__name__ } ' to be compiled "
232270 "did not return an array container, but an instance of "
233- f"'{ outputs .__class__ } ' instead." )
271+ f"'{ output_template .__class__ } ' instead." )
234272
235273 def _as_dict_of_named_arrays (keys , ary ):
236274 name = "_pt_out_" + "_" .join (str (key )
237275 for key in keys )
238- output_naming_map [keys ] = name
276+ output_id_to_name_in_program [keys ] = name
239277 dict_of_named_arrays [name ] = ary
240278 return ary
241279
242280 rec_keyed_map_array_container (_as_dict_of_named_arrays ,
243- outputs )
281+ output_template )
244282
245- import loopy as lp
283+ from pytato import DictOfNamedArrays
284+ compiled_func = self ._dag_to_compiled_func (
285+ DictOfNamedArrays (dict_of_named_arrays ),
286+ input_id_to_name_in_program = input_id_to_name_in_program ,
287+ output_id_to_name_in_program = output_id_to_name_in_program ,
288+ output_template = output_template )
246289
247- with ProcessLogger (logger , "transform_dag" ):
248- pt_dict_of_named_arrays = self .actx .transform_dag (
249- pt .make_dict_of_named_arrays (dict_of_named_arrays ))
290+ self .program_cache [arg_id_to_descr ] = compiled_func
291+ return compiled_func (arg_id_to_arg )
250292
251- with ProcessLogger (logger , "generate_loopy" ):
252- pytato_program = pt .generate_loopy (pt_dict_of_named_arrays ,
253- options = lp .Options (
254- return_dict = True ,
255- no_numpy = True ),
256- cl_device = self .actx .queue .device )
257- assert isinstance (pytato_program , BoundPyOpenCLProgram )
258293
259- with ProcessLogger (logger , "transform_loopy_program" ):
294+ def _args_to_cl_buffers (actx , input_id_to_name_in_program , arg_id_to_arg ):
295+ input_kwargs_for_loopy = {}
260296
261- pytato_program = (pytato_program
262- .with_transformed_program (
263- lambda x : x .with_kernel (
264- x .default_entrypoint
265- .tagged (FromArrayContextCompile ()))))
266-
267- pytato_program = (pytato_program
268- .with_transformed_program (self
269- .actx
270- .transform_loopy_program ))
297+ for arg_id , arg in arg_id_to_arg .items ():
298+ if np .isscalar (arg ):
299+ arg = cla .to_device (actx .queue , np .array (arg ))
300+ elif isinstance (arg , pt .array .DataWrapper ):
301+ # got a Datwwrapper => simply gets its data
302+ arg = arg .data
303+ elif isinstance (arg , cla .Array ):
304+ # got a frozen array => do nothing
305+ pass
306+ elif isinstance (arg , pt .Array ):
307+ # got an array expression => evaluate it
308+ arg = actx .freeze (arg ).with_queue (actx .queue )
309+ else :
310+ raise NotImplementedError (type (arg ))
271311
272- self .program_cache [arg_id_to_descr ] = CompiledFunction (
273- self .actx , pytato_program ,
274- input_naming_map , output_naming_map ,
275- output_template = outputs )
312+ input_kwargs_for_loopy [input_id_to_name_in_program [arg_id ]] = arg
276313
277- return self . program_cache [ arg_id_to_descr ]( arg_id_to_arg )
314+ return input_kwargs_for_loopy
278315
279316
280- @dataclass
317+ @dataclass ( frozen = True )
281318class CompiledFunction :
282319 """
283320 A callable which captures the :class:`pytato.target.BoundProgram` resulting
@@ -319,40 +356,18 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
319356 :attr:`CompiledFunction.input_id_to_name_in_program` for input id's
320357 representation.
321358 """
322- from arraycontext .container .traversal import rec_keyed_map_array_container
323-
324- input_kwargs_to_loopy = {}
325-
326- # {{{ preprocess args to get arguments (CL buffers) to be fed to the
327- # loopy program
328-
329- for arg_id , arg in arg_id_to_arg .items ():
330- if np .isscalar (arg ):
331- arg = cla .to_device (self .actx .queue , np .array (arg ))
332- elif isinstance (arg , pt .array .DataWrapper ):
333- # got a Datwwrapper => simply gets its data
334- arg = arg .data
335- elif isinstance (arg , cla .Array ):
336- # got a frozen array => do nothing
337- pass
338- elif isinstance (arg , pt .Array ):
339- # got an array expression => evaluate it
340- arg = self .actx .freeze (arg ).with_queue (self .actx .queue )
341- else :
342- raise NotImplementedError (type (arg ))
343-
344- input_kwargs_to_loopy [self .input_id_to_name_in_program [arg_id ]] = arg
359+ input_kwargs_for_loopy = _args_to_cl_buffers (
360+ self .actx , self .input_id_to_name_in_program , arg_id_to_arg )
345361
346362 evt , out_dict = self .pytato_program (queue = self .actx .queue ,
347363 allocator = self .actx .allocator ,
348- ** input_kwargs_to_loopy )
364+ ** input_kwargs_for_loopy )
365+
349366 # FIXME Kernels (for now) allocate tons of memory in temporaries. If we
350367 # race too far ahead with enqueuing, there is a distinct risk of
351368 # running out of memory. This mitigates that risk a bit, for now.
352369 evt .wait ()
353370
354- # }}}
355-
356371 def to_output_template (keys , _ ):
357372 return self .actx .thaw (out_dict [self .output_id_to_name_in_program [keys ]])
358373
0 commit comments