|
2 | 2 | .. currentmodule:: arraycontext.impl.pytato.compile |
3 | 3 | .. autoclass:: LazilyCompilingFunctionCaller |
4 | 4 | .. autoclass:: CompiledFunction |
| 5 | +.. autoclass:: FromArrayContextCompile |
5 | 6 | """ |
6 | 7 | __copyright__ = """ |
7 | 8 | Copyright (C) 2020-1 University of Illinois Board of Trustees |
|
40 | 41 | import pyopencl.array as cla |
41 | 42 | import pytato as pt |
42 | 43 | import itertools |
| 44 | +from pytools.tag import Tag |
43 | 45 |
|
44 | 46 | from pytools import ProcessLogger |
45 | 47 |
|
46 | 48 | import logging |
47 | 49 | logger = logging.getLogger(__name__) |
48 | 50 |
|
49 | 51 |
|
| 52 | +class FromArrayContextCompile(Tag): |
| 53 | + """ |
| 54 | + Tagged to the entrypoint kernel of every translation unit that is generated |
| 55 | + by :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. |
| 56 | +
|
| 57 | + Typically this tag serves as a branch condition in implementing a |
| 58 | + specialized transform strategy for kernels compiled by |
| 59 | + :meth:`~arraycontext.PytatoPyOpenCLArrayContext.compile`. |
| 60 | + """ |
| 61 | + |
| 62 | + |
50 | 63 | # {{{ helper classes: AbstractInputDescriptor |
51 | 64 |
|
52 | 65 | class AbstractInputDescriptor: |
@@ -245,6 +258,13 @@ def _as_dict_of_named_arrays(keys, ary): |
245 | 258 | assert isinstance(pytato_program, BoundPyOpenCLProgram) |
246 | 259 |
|
247 | 260 | with ProcessLogger(logger, "transform_loopy_program"): |
| 261 | + |
| 262 | + pytato_program = (pytato_program |
| 263 | + .with_transformed_program( |
| 264 | + lambda x: x.with_kernel( |
| 265 | + x.default_entrypoint |
| 266 | + .tagged(FromArrayContextCompile())))) |
| 267 | + |
248 | 268 | pytato_program = (pytato_program |
249 | 269 | .with_transformed_program(self |
250 | 270 | .actx |
|
0 commit comments