Skip to content

Commit 221e583

Browse files
committed
Derive BatchedEinsumArrayContext/FusionArrayContext from arraycontext's implementations
1 parent f26cfe3 commit 221e583

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

meshmode/array_context.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,13 @@
3434
_PytestPytatoPyOpenCLArrayContextFactory,
3535
register_pytest_array_context_factory)
3636

37+
from meshmode.arraycontext_extras.batched_einsum import (
38+
BatchedEinsumPytatoPyOpenCLArrayContext as BaseBatchedEinsumPytatoPyOpenCLArrayContext) # noqa: E501
39+
from meshmode.arraycontext_extras.split_actx import SplitPytatoPyOpenCLArrayContext
40+
41+
from pytools.tag import Tag
42+
from typing import Optional, Callable, Any
43+
3744

3845
def thaw(actx, ary):
3946
warn("meshmode.array_context.thaw is deprecated. Use arraycontext.thaw instead. "
@@ -345,4 +352,76 @@ def _import_names():
345352
# }}}
346353

347354

355+
def _fused_loop_name_prefix_getter(tag: Tag) -> str:
356+
from meshmode.transform_metadata import (
357+
DiscretizationElementAxisTag,
358+
DiscretizationFaceAxisTag,
359+
DiscretizationDOFAxisTag,
360+
DiscretizationAmbientDimAxisTag,
361+
DiscretizationTopologicalDimAxisTag,
362+
DiscretizationFlattenedDOFAxisTag,
363+
)
364+
if isinstance(tag, DiscretizationElementAxisTag):
365+
return "iel"
366+
elif isinstance(tag, DiscretizationFaceAxisTag):
367+
return "iface"
368+
elif isinstance(tag, DiscretizationDOFAxisTag):
369+
return "idof"
370+
elif isinstance(tag, DiscretizationAmbientDimAxisTag):
371+
return "iambient_dim"
372+
elif isinstance(tag, DiscretizationTopologicalDimAxisTag):
373+
return "itopo_dim"
374+
elif isinstance(tag, DiscretizationTopologicalDimAxisTag):
375+
return "itopo_dim"
376+
elif isinstance(tag, DiscretizationFlattenedDOFAxisTag):
377+
return "iflatted_dofs"
378+
else:
379+
raise NotImplementedError(type(tag))
380+
381+
382+
class BatchedEinsumPytatoPyOpenCLArrayContext(
383+
BaseBatchedEinsumPytatoPyOpenCLArrayContext):
384+
def __init__(
385+
self,
386+
queue, allocator=None,
387+
*,
388+
fallback_to_no_fusion: bool = True,
389+
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
390+
feinsum_db: Optional[str] = None,
391+
log_loopy_statistics: bool = False,
392+
) -> None:
393+
from meshmode.transform_metadata import DiscretizationEntityAxisTag
394+
import feinsum as fnsm
395+
396+
super().__init__(
397+
queue, allocator,
398+
loop_fusion_axis_tag_t=DiscretizationEntityAxisTag,
399+
fallback_to_no_fusion=fallback_to_no_fusion,
400+
assume_all_indirection_maps_as_non_negative=True,
401+
compile_trace_callback=compile_trace_callback,
402+
feinsum_db=fnsm.DEFAULT_DB,
403+
log_loopy_statistics=log_loopy_statistics,
404+
fused_loop_name_prefix_getter=_fused_loop_name_prefix_getter
405+
)
406+
407+
408+
class FusionContractorArrayContext(BatchedEinsumPytatoPyOpenCLArrayContext):
409+
def __init__(self, *args, **kwargs):
410+
from warnings import warn
411+
warn("FusionContractorArrayContext is deprecated, use"
412+
" 'BatchedEinsumPytatoPyOpenCLArraYContext' instead."
413+
" This will be an error from June, 2023.",
414+
DeprecationWarning, stacklevel=2)
415+
super().__init__(*args, **kwargs)
416+
417+
418+
class SingleGridWorkBalancingPytatoArrayContext(SplitPytatoPyOpenCLArrayContext):
419+
def __init__(self, *args, **kwargs):
420+
from warnings import warn
421+
warn("SingleGridWorkBalancingPytatoArrayContext is deprecated, use"
422+
" 'SplitPytatoPyOpenCLArrayContext' instead. This will be an"
423+
" error from June, 2023.",
424+
DeprecationWarning, stacklevel=2)
425+
super().__init__(*args, **kwargs)
426+
348427
# vim: foldmethod=marker

0 commit comments

Comments
 (0)