Skip to content

Commit 9ab0849

Browse files
committed
Add folding sections in arraycontext.impl.pytato.compile
1 parent e47d0cf commit 9ab0849

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,8 @@ class LeafArrayDescriptor(AbstractInputDescriptor):
105105
# }}}
106106

107107

108+
# {{{ utilities
109+
108110
def _ary_container_key_stringifier(keys: Tuple[Any, ...]) -> str:
109111
"""
110112
Helper for :meth:`BaseLazilyCompilingFunctionCaller.__call__`. Stringifies an
@@ -236,6 +238,10 @@ def _rec_to_placeholder(keys, ary):
236238
else:
237239
raise NotImplementedError(type(arg))
238240

241+
# }}}
242+
243+
244+
# {{{ BaseLazilyCompilingFunctionCaller
239245

240246
@dataclass
241247
class BaseLazilyCompilingFunctionCaller:
@@ -366,6 +372,10 @@ def _as_dict_of_named_arrays(keys, ary):
366372
self.program_cache[arg_id_to_descr] = compiled_func
367373
return compiled_func(arg_id_to_arg)
368374

375+
# }}}
376+
377+
378+
# {{{ LazilyPyOpenCLCompilingFunctionCaller
369379

370380
class LazilyPyOpenCLCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
371381
@property
@@ -440,6 +450,8 @@ def _dag_to_transformed_pytato_prg(self, dict_of_named_arrays, *, prg_id=None):
440450

441451
return pytato_program, name_in_program_to_tags, name_in_program_to_axes
442452

453+
# }}}
454+
443455

444456
# {{{ preserve back compat
445457

@@ -461,6 +473,8 @@ def _dag_to_transformed_loopy_prg(self, dict_of_named_arrays):
461473
# }}}
462474

463475

476+
# {{{ LazilyJAXCompilingFunctionCaller
477+
464478
class LazilyJAXCompilingFunctionCaller(BaseLazilyCompilingFunctionCaller):
465479
@property
466480
def compiled_function_returning_array_container_class(
@@ -553,6 +567,10 @@ def _args_to_cl_buffers(actx, input_id_to_name_in_program, arg_id_to_arg):
553567
return _args_to_device_buffers(actx, input_id_to_name_in_program,
554568
arg_id_to_arg)
555569

570+
# }}}
571+
572+
573+
# {{{ compiled function
556574

557575
class CompiledFunction(abc.ABC):
558576
"""
@@ -582,6 +600,10 @@ def __call__(self, arg_id_to_arg) -> Any:
582600
"""
583601
pass
584602

603+
# }}}
604+
605+
606+
# {{{ copmiled pyopencl function
585607

586608
@dataclass(frozen=True)
587609
class CompiledPyOpenCLFunctionReturningArrayContainer(CompiledFunction):
@@ -670,7 +692,10 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
670692
self.output_axes),
671693
tags=self.output_tags))
672694

695+
# }}}
696+
673697

698+
# {{{ comiled jax function
674699
@dataclass(frozen=True)
675700
class CompiledJAXFunctionReturningArrayContainer(CompiledFunction):
676701
"""
@@ -732,3 +757,7 @@ def __call__(self, arg_id_to_arg) -> ArrayContainer:
732757
evt, out_dict = self.pytato_program(**input_kwargs_for_loopy)
733758

734759
return self.actx.thaw(out_dict[self.output_name])
760+
761+
# }}}
762+
763+
# vim: foldmethod=marker

0 commit comments

Comments
 (0)