Skip to content

Commit aaf1607

Browse files
committed
Implement PytatoSplitArrayContext
1 parent 50511fe commit aaf1607

File tree

3 files changed

+694
-0
lines changed

3 files changed

+694
-0
lines changed

arraycontext/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from .impl.jax import EagerJAXArrayContext
5454
from .impl.pyopencl import PyOpenCLArrayContext
5555
from .impl.pytato import PytatoJAXArrayContext, PytatoPyOpenCLArrayContext
56+
from .impl.pytato.split_actx import SplitPytatoPyOpenCLArrayContext
5657
from .loopy import make_loopy_program
5758
# deprecated, remove in 2022.
5859
from .metadata import _FirstAxisIsElementsTag
@@ -98,6 +99,8 @@
9899
"outer",
99100

100101
"PyOpenCLArrayContext", "PytatoPyOpenCLArrayContext",
102+
"SplitPytatoPyOpenCLArrayContext",
103+
101104
"PytatoJAXArrayContext",
102105
"EagerJAXArrayContext",
103106

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,143 @@
1+
"""
2+
.. autoclass:: SplitPytatoPyOpenCLArrayContext
3+
4+
"""
5+
6+
__copyright__ = """
7+
Copyright (C) 2023 Kaushik Kulkarni
8+
Copyright (C) 2023 Andreas Kloeckner
9+
Copyright (C) 2022 Matthias Diener
10+
Copyright (C) 2022 Matt Smith
11+
"""
12+
13+
__license__ = """
14+
Permission is hereby granted, free of charge, to any person obtaining a copy
15+
of this software and associated documentation files (the "Software"), to deal
16+
in the Software without restriction, including without limitation the rights
17+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
18+
copies of the Software, and to permit persons to whom the Software is
19+
furnished to do so, subject to the following conditions:
20+
21+
The above copyright notice and this permission notice shall be included in
22+
all copies or substantial portions of the Software.
23+
24+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
25+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
26+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
27+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
28+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
29+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
30+
THE SOFTWARE.
31+
"""
32+
33+
import sys
34+
from typing import TYPE_CHECKING
35+
36+
import loopy as lp
37+
38+
from arraycontext.impl.pytato import PytatoPyOpenCLArrayContext
39+
40+
41+
if TYPE_CHECKING or getattr(sys, "_BUILDING_SPHINX_DOCS", False):
42+
import pytato
43+
44+
45+
class SplitPytatoPyOpenCLArrayContext(PytatoPyOpenCLArrayContext):
46+
"""
47+
.. note::
48+
49+
Refer to :meth:`transform_dag` and :meth:`transform_loopy_program` for
50+
details on the transformation algorithm provided by this array context.
51+
52+
.. warning::
53+
54+
For expression graphs with large number of nodes high compile times are
55+
expected.
56+
"""
57+
def transform_dag(self,
58+
dag: "pytato.DictOfNamedArrays") -> "pytato.DictOfNamedArrays":
59+
r"""
60+
Returns a transformed version of *dag*, where the applied transform is:
61+
62+
#. Materialize as per MPMS materialization heuristic.
63+
#. materialize every :class:`pytato.array.Einsum`\ 's inputs and outputs.
64+
"""
65+
import pytato as pt
66+
67+
# Step 1. Collapse equivalent nodes in DAG.
68+
# -----------------------------------------
69+
# type-ignore-reason: mypy is right pytato provides imprecise types.
70+
dag = pt.transform.deduplicate_data_wrappers(dag) # type: ignore[assignment]
71+
72+
# Step 2. Materialize reduction inputs/outputs.
73+
# ------------------------------------------
74+
from .utils import (
75+
get_inputs_and_outputs_of_einsum,
76+
get_inputs_and_outputs_of_reduction_nodes)
77+
78+
reduction_inputs_outputs = frozenset.union(
79+
*get_inputs_and_outputs_of_einsum(dag),
80+
*get_inputs_and_outputs_of_reduction_nodes(dag)
81+
)
82+
83+
def materialize_einsum(expr: pt.transform.ArrayOrNames
84+
) -> pt.transform.ArrayOrNames:
85+
if expr in reduction_inputs_outputs:
86+
if isinstance(expr, pt.InputArgumentBase):
87+
return expr
88+
else:
89+
return expr.tagged(pt.tags.ImplStored())
90+
else:
91+
return expr
92+
93+
# type-ignore-reason: mypy is right pytato provides imprecise types.
94+
dag = pt.transform.map_and_copy(dag, # type: ignore[assignment]
95+
materialize_einsum)
96+
97+
# Step 3. MPMS materialize
98+
# ------------------------
99+
dag = pt.transform.materialize_with_mpms(dag)
100+
101+
return dag
102+
103+
def transform_loopy_program(self,
104+
t_unit: lp.TranslationUnit) -> lp.TranslationUnit:
105+
r"""
106+
Returns a transformed version of *t_unit*, where the applied transform is:
107+
108+
#. An execution grid size :math:`G` is selected based on *self*'s
109+
OpenCL-device.
110+
#. The iteration domain for each statement in the *t_unit* is divided to
111+
equally among the work-items in :math:`G`.
112+
#. Kernel boundaries are drawn between every statement in the instruction.
113+
Although one can relax this constraint by letting :mod:`loopy` compute
114+
where to insert the global barriers, but it is not guaranteed to be
115+
performance profitable since we do not attempt any further loop-fusion
116+
and/or array contraction.
117+
#. Once the kernel boundaries are inferred, :func:`alias_global_temporaries`
118+
is invoked to reduce the memory peak memory used by the transformed
119+
program.
120+
"""
121+
# Step 1. Split the iteration across work-items
122+
# ---------------------------------------------
123+
from .utils import split_iteration_domain_across_work_items
124+
t_unit = split_iteration_domain_across_work_items(t_unit, self.queue.device)
125+
126+
# Step 2. Add a global barrier between individual loop nests.
127+
# ------------------------------------------------------
128+
from .utils import add_gbarrier_between_disjoint_loop_nests
129+
t_unit = add_gbarrier_between_disjoint_loop_nests(t_unit)
130+
131+
# Step 3. Transform reduce to scalar statements
132+
# ---------------------------------------------
133+
from .utils import parallelize_reduce_to_scalars
134+
t_unit = parallelize_reduce_to_scalars(t_unit, self.queue.device)
135+
136+
# Step 4. Alias global temporaries with disjoint live intervals
137+
# -------------------------------------------------------------
138+
from .utils import alias_global_temporaries
139+
t_unit = alias_global_temporaries(t_unit)
140+
141+
return t_unit
142+
143+
# vim: fdm=marker

0 commit comments

Comments
 (0)