@@ -105,6 +105,8 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
105105# }}}
106106
107107
108+ # {{{ utilities
109+
108110def _ary_container_key_stringifier (keys : Tuple [Any , ...]) -> str :
109111 """
110112 Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
@@ -236,6 +238,10 @@ def _rec_to_placeholder(keys, ary):
236238 else :
237239 raise NotImplementedError (type (arg ))
238240
241+ # }}}
242+
243+
244+ # {{{ BaseLazilyCompilingFunctionCaller
239245
240246@dataclass
241247class BaseLazilyCompilingFunctionCaller :
@@ -366,6 +372,10 @@ def _as_dict_of_named_arrays(keys, ary):
366372 self .program_cache [arg_id_to_descr ] = compiled_func
367373 return compiled_func (arg_id_to_arg )
368374
375+ # }}}
376+
377+
378+ # {{{ LazilyPyOpenCLCompilingFunctionCaller
369379
370380class LazilyPyOpenCLCompilingFunctionCaller (BaseLazilyCompilingFunctionCaller ):
371381 @property
@@ -440,6 +450,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
440450
441451 return pytato_program , name_in_program_to_tags , name_in_program_to_axes
442452
453+ # }}}
454+
443455
444456# {{{ preserve back compat
445457
@@ -461,6 +473,8 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
461473# }}}
462474
463475
476+ # {{{ LazilyJAXCompilingFunctionCaller
477+
464478class LazilyJAXCompilingFunctionCaller (BaseLazilyCompilingFunctionCaller ):
465479 @property
466480 def compiled_function_returning_array_container_class (
@@ -553,6 +567,10 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
553567 return _args_to_device_buffers (actx , input_id_to_name_in_program ,
554568 arg_id_to_arg )
555569
570+ # }}}
571+
572+
573+ # {{{ compiled function
556574
557575class CompiledFunction (abc .ABC ):
558576 """
@@ -582,6 +600,10 @@ def __call__(self, arg_id_to_arg) -> Any:
582600 """
583601 pass
584602
603+ # }}}
604+
605+
606+ # {{{ copmiled pyopencl function
585607
586608@dataclass (frozen = True )
587609class CompiledPyOpenCLFunctionReturningArrayContainer (CompiledFunction ):
@@ -670,7 +692,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
670692 self .output_axes ),
671693 tags = self .output_tags ))
672694
695+ # }}}
696+
673697
698+ # {{{ comiled jax function
674699@dataclass (frozen = True )
675700class CompiledJAXFunctionReturningArrayContainer (CompiledFunction ):
676701 """
@@ -732,3 +757,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
732757 evt , out_dict = self .pytato_program (** input_kwargs_for_loopy )
733758
734759 return self .actx .thaw (out_dict [self .output_name ])
760+
761+ # }}}
762+
763+ # vim: foldmethod=marker
0 commit comments