Skip to content

Commit 30ed00b

Browse files
committed
Derive BatchedEinsumArrayContext/FusionArrayContext from arraycontext's implementations
1 parent f0962ca commit 30ed00b

File tree

1 file changed

+79
-0
lines changed

1 file changed

+79
-0
lines changed

meshmode/array_context.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,13 @@
3535
_PytestPyOpenCLArrayContextFactoryWithClass,
3636
_PytestPytatoPyOpenCLArrayContextFactory, register_pytest_array_context_factory)
3737

38+
from meshmode.arraycontext_extras.batched_einsum import (
39+
BatchedEinsumPytatoPyOpenCLArrayContext as BaseBatchedEinsumPytatoPyOpenCLArrayContext) # noqa: E501
40+
from meshmode.arraycontext_extras.split_actx import SplitPytatoPyOpenCLArrayContext
41+
42+
from pytools.tag import Tag
43+
from typing import Optional, Callable, Any
44+
3845

3946
def thaw(actx, ary):
4047
warn("meshmode.array_context.thaw is deprecated. Use arraycontext.thaw instead. "
@@ -346,4 +353,76 @@ def _import_names():
346353
# }}}
347354

348355

356+
def _fused_loop_name_prefix_getter(tag: Tag) -> str:
357+
from meshmode.transform_metadata import (
358+
DiscretizationElementAxisTag,
359+
DiscretizationFaceAxisTag,
360+
DiscretizationDOFAxisTag,
361+
DiscretizationAmbientDimAxisTag,
362+
DiscretizationTopologicalDimAxisTag,
363+
DiscretizationFlattenedDOFAxisTag,
364+
)
365+
if isinstance(tag, DiscretizationElementAxisTag):
366+
return "iel"
367+
elif isinstance(tag, DiscretizationFaceAxisTag):
368+
return "iface"
369+
elif isinstance(tag, DiscretizationDOFAxisTag):
370+
return "idof"
371+
elif isinstance(tag, DiscretizationAmbientDimAxisTag):
372+
return "iambient_dim"
373+
elif isinstance(tag, DiscretizationTopologicalDimAxisTag):
374+
return "itopo_dim"
375+
elif isinstance(tag, DiscretizationTopologicalDimAxisTag):
376+
return "itopo_dim"
377+
elif isinstance(tag, DiscretizationFlattenedDOFAxisTag):
378+
return "iflatted_dofs"
379+
else:
380+
raise NotImplementedError(type(tag))
381+
382+
383+
class BatchedEinsumPytatoPyOpenCLArrayContext(
384+
BaseBatchedEinsumPytatoPyOpenCLArrayContext):
385+
def __init__(
386+
self,
387+
queue, allocator=None,
388+
*,
389+
fallback_to_no_fusion: bool = True,
390+
compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None,
391+
feinsum_db: Optional[str] = None,
392+
log_loopy_statistics: bool = False,
393+
) -> None:
394+
from meshmode.transform_metadata import DiscretizationEntityAxisTag
395+
import feinsum as fnsm
396+
397+
super().__init__(
398+
queue, allocator,
399+
loop_fusion_axis_tag_t=DiscretizationEntityAxisTag,
400+
fallback_to_no_fusion=fallback_to_no_fusion,
401+
assume_all_indirection_maps_as_non_negative=True,
402+
compile_trace_callback=compile_trace_callback,
403+
feinsum_db=fnsm.DEFAULT_DB,
404+
log_loopy_statistics=log_loopy_statistics,
405+
fused_loop_name_prefix_getter=_fused_loop_name_prefix_getter
406+
)
407+
408+
409+
class FusionContractorArrayContext(BatchedEinsumPytatoPyOpenCLArrayContext):
410+
def __init__(self, *args, **kwargs):
411+
from warnings import warn
412+
warn("FusionContractorArrayContext is deprecated, use"
413+
" 'BatchedEinsumPytatoPyOpenCLArraYContext' instead."
414+
" This will be an error from June, 2023.",
415+
DeprecationWarning, stacklevel=2)
416+
super().__init__(*args, **kwargs)
417+
418+
419+
class SingleGridWorkBalancingPytatoArrayContext(SplitPytatoPyOpenCLArrayContext):
420+
def __init__(self, *args, **kwargs):
421+
from warnings import warn
422+
warn("SingleGridWorkBalancingPytatoArrayContext is deprecated, use"
423+
" 'SplitPytatoPyOpenCLArrayContext' instead. This will be an"
424+
" error from June, 2023.",
425+
DeprecationWarning, stacklevel=2)
426+
super().__init__(*args, **kwargs)
427+
349428
# vim: foldmethod=marker

0 commit comments

Comments
 (0)