Skip to content

Commit 13eed85

Browse files
kaushikcfdinducer
authored andcommitted
introduces FromArrayContextCompile tag
1 parent 6eaf518 commit 13eed85

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

arraycontext/impl/pytato/compile.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
.. currentmodule:: arraycontext.impl.pytato.compile
33
.. autoclass:: LazilyCompilingFunctionCaller
44
.. autoclass:: CompiledFunction
5+
.. autoclass:: FromArrayContextCompile
56
"""
67
__copyright__ = """
78
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -40,13 +41,25 @@
4041
import pyopencl.array as cla
4142
import pytato as pt
4243
import itertools
44+
from pytools.tag import Tag
4345

4446
from pytools import ProcessLogger
4547

4648
import logging
4749
logger = logging.getLogger(__name__)
4850

4951

52+
class FromArrayContextCompile(Tag):
53+
"""
54+
Tagged to the entrypoint kernel of every translation unit that is generated
55+
by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
56+
57+
Typically this tag serves as a branch condition in implementing a
58+
specialized transform strategy for kernels compiled by
59+
:meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`.
60+
"""
61+
62+
5063
# {{{ helper classes: AbstractInputDescriptor
5164

5265
class AbstractInputDescriptor:
@@ -245,6 +258,13 @@ def _as_dict_of_named_arrays(keys, ary):
245258
assert isinstance(pytato_program, BoundPyOpenCLProgram)
246259

247260
with ProcessLogger(logger, "transform_loopy_program"):
261+
262+
pytato_program = (pytato_program
263+
.with_transformed_program(
264+
lambda x: x.with_kernel(
265+
x.default_entrypoint
266+
.tagged(FromArrayContextCompile()))))
267+
248268
pytato_program = (pytato_program
249269
.with_transformed_program(self
250270
.actx

0 commit comments

Comments
 (0)