|
34 | 34 | _PytestPytatoPyOpenCLArrayContextFactory, |
35 | 35 | register_pytest_array_context_factory) |
36 | 36 |
|
| 37 | +from meshmode.arraycontext_extras.batched_einsum import ( |
| 38 | + BatchedEinsumPytatoPyOpenCLArrayContext as BaseBatchedEinsumPytatoPyOpenCLArrayContext) # noqa: E501 |
| 39 | +from meshmode.arraycontext_extras.split_actx import SplitPytatoPyOpenCLArrayContext |
| 40 | + |
| 41 | +from pytools.tag import Tag |
| 42 | +from typing import Optional, Callable, Any |
| 43 | + |
37 | 44 |
|
38 | 45 | def thaw(actx, ary): |
39 | 46 | warn("meshmode.array_context.thaw is deprecated. Use arraycontext.thaw instead. " |
@@ -345,4 +352,76 @@ def _import_names(): |
345 | 352 | # }}} |
346 | 353 |
|
347 | 354 |
|
| 355 | +def _fused_loop_name_prefix_getter(tag: Tag) -> str: |
| 356 | + from meshmode.transform_metadata import ( |
| 357 | + DiscretizationElementAxisTag, |
| 358 | + DiscretizationFaceAxisTag, |
| 359 | + DiscretizationDOFAxisTag, |
| 360 | + DiscretizationAmbientDimAxisTag, |
| 361 | + DiscretizationTopologicalDimAxisTag, |
| 362 | + DiscretizationFlattenedDOFAxisTag, |
| 363 | + ) |
| 364 | + if isinstance(tag, DiscretizationElementAxisTag): |
| 365 | + return "iel" |
| 366 | + elif isinstance(tag, DiscretizationFaceAxisTag): |
| 367 | + return "iface" |
| 368 | + elif isinstance(tag, DiscretizationDOFAxisTag): |
| 369 | + return "idof" |
| 370 | + elif isinstance(tag, DiscretizationAmbientDimAxisTag): |
| 371 | + return "iambient_dim" |
| 372 | + elif isinstance(tag, DiscretizationTopologicalDimAxisTag): |
| 373 | + return "itopo_dim" |
| 374 | + elif isinstance(tag, DiscretizationTopologicalDimAxisTag): |
| 375 | + return "itopo_dim" |
| 376 | + elif isinstance(tag, DiscretizationFlattenedDOFAxisTag): |
| 377 | + return "iflatted_dofs" |
| 378 | + else: |
| 379 | + raise NotImplementedError(type(tag)) |
| 380 | + |
| 381 | + |
| 382 | +class BatchedEinsumPytatoPyOpenCLArrayContext( |
| 383 | + BaseBatchedEinsumPytatoPyOpenCLArrayContext): |
| 384 | + def __init__( |
| 385 | + self, |
| 386 | + queue, allocator=None, |
| 387 | + *, |
| 388 | + fallback_to_no_fusion: bool = True, |
| 389 | + compile_trace_callback: Optional[Callable[[Any, str, Any], None]] = None, |
| 390 | + feinsum_db: Optional[str] = None, |
| 391 | + log_loopy_statistics: bool = False, |
| 392 | + ) -> None: |
| 393 | + from meshmode.transform_metadata import DiscretizationEntityAxisTag |
| 394 | + import feinsum as fnsm |
| 395 | + |
| 396 | + super().__init__( |
| 397 | + queue, allocator, |
| 398 | + loop_fusion_axis_tag_t=DiscretizationEntityAxisTag, |
| 399 | + fallback_to_no_fusion=fallback_to_no_fusion, |
| 400 | + assume_all_indirection_maps_as_non_negative=True, |
| 401 | + compile_trace_callback=compile_trace_callback, |
| 402 | + feinsum_db=fnsm.DEFAULT_DB, |
| 403 | + log_loopy_statistics=log_loopy_statistics, |
| 404 | + fused_loop_name_prefix_getter=_fused_loop_name_prefix_getter |
| 405 | + ) |
| 406 | + |
| 407 | + |
| 408 | +class FusionContractorArrayContext(BatchedEinsumPytatoPyOpenCLArrayContext): |
| 409 | + def __init__(self, *args, **kwargs): |
| 410 | + from warnings import warn |
| 411 | + warn("FusionContractorArrayContext is deprecated, use" |
| 412 | + " 'BatchedEinsumPytatoPyOpenCLArraYContext' instead." |
| 413 | + " This will be an error from June, 2023.", |
| 414 | + DeprecationWarning, stacklevel=2) |
| 415 | + super().__init__(*args, **kwargs) |
| 416 | + |
| 417 | + |
| 418 | +class SingleGridWorkBalancingPytatoArrayContext(SplitPytatoPyOpenCLArrayContext): |
| 419 | + def __init__(self, *args, **kwargs): |
| 420 | + from warnings import warn |
| 421 | + warn("SingleGridWorkBalancingPytatoArrayContext is deprecated, use" |
| 422 | + " 'SplitPytatoPyOpenCLArrayContext' instead. This will be an" |
| 423 | + " error from June, 2023.", |
| 424 | + DeprecationWarning, stacklevel=2) |
| 425 | + super().__init__(*args, **kwargs) |
| 426 | + |
348 | 427 | # vim: foldmethod=marker |
0 commit comments