Skip to content

Commit bed32d0

Browse files
alexfiklinducer
authored andcommitted
feat(typing): add types to transform.privatize
1 parent 3b73bb8 commit bed32d0

File tree

1 file changed

+93
-57
lines changed

1 file changed

+93
-57
lines changed

loopy/transform/privatize.py

Lines changed: 93 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,26 @@
2323
THE SOFTWARE.
2424
"""
2525

26-
2726
import logging
27+
from typing import TYPE_CHECKING
28+
29+
from typing_extensions import override
2830

29-
import pymbolic
3031
import pymbolic.primitives as p
3132

3233
from loopy.diagnostic import LoopyError
3334
from loopy.translation_unit import for_each_kernel
3435

3536

37+
if TYPE_CHECKING:
38+
from collections.abc import Mapping, Sequence
39+
40+
from pymbolic.typing import ArithmeticExpression, Expression
41+
42+
from loopy.kernel import LoopKernel
43+
from loopy.kernel.instruction import InstructionBase
44+
from loopy.typing import InameStr, InameStrSet
45+
3646
logger = logging.getLogger(__name__)
3747

3848

@@ -51,13 +61,20 @@
5161

5262

5363
class ExtraInameIndexInserter(IdentityMapper[[]]):
54-
def __init__(self, var_to_new_inames, iname_to_lbound):
64+
var_to_new_inames: Mapping[str, Sequence[p.Variable]]
65+
iname_to_lbound: Mapping[str, ArithmeticExpression]
66+
seen_priv_axis_inames: set[str]
67+
68+
def __init__(self,
69+
var_to_new_inames: Mapping[str, Sequence[p.Variable]],
70+
iname_to_lbound: Mapping[str, ArithmeticExpression]) -> None:
5571
self.var_to_new_inames = var_to_new_inames
5672
self.iname_to_lbound = iname_to_lbound
5773
self.seen_priv_axis_inames = set()
5874
super().__init__()
5975

60-
def map_subscript(self, expr: p.Subscript):
76+
@override
77+
def map_subscript(self, expr: p.Subscript, /) -> Expression:
6178
assert isinstance(expr.aggregate, p.Variable)
6279
try:
6380
extra_idx = self.var_to_new_inames[expr.aggregate.name]
@@ -71,32 +88,36 @@ def map_subscript(self, expr: p.Subscript):
7188

7289
self.seen_priv_axis_inames.update(v.name for v in extra_idx)
7390

74-
new_idx = index + tuple(flatten(v - self.iname_to_lbound[v.name])
75-
for v in extra_idx)
91+
new_idx = index + tuple(
92+
flatten(v - self.iname_to_lbound[v.name]) for v in extra_idx
93+
)
7694

7795
if len(new_idx) == 1:
7896
new_idx = new_idx[0]
7997
return expr.aggregate[new_idx]
8098

81-
def map_variable(self, expr: p.Variable):
99+
@override
100+
def map_variable(self, expr: p.Variable, /) -> Expression:
82101
try:
83102
new_idx = self.var_to_new_inames[expr.name]
84103
except KeyError:
85104
return expr
86105
else:
87106
self.seen_priv_axis_inames.update(v.name for v in new_idx)
88107

89-
new_idx = tuple(flatten(v - self.iname_to_lbound[v.name])
90-
for v in new_idx)
91-
108+
new_idx = tuple(flatten(v - self.iname_to_lbound[v.name]) for v in new_idx)
92109
if len(new_idx) == 1:
93110
new_idx = new_idx[0]
111+
94112
return expr[new_idx]
95113

96114

97115
@for_each_kernel
98116
def privatize_temporaries_with_inames(
99-
kernel, privatizing_inames, only_var_names=None):
117+
kernel: LoopKernel,
118+
privatizing_inames: InameStr | InameStrSet,
119+
only_var_names: InameStr | InameStrSet | None = None,
120+
) -> LoopKernel:
100121
"""This function provides each loop iteration of the *privatizing_inames*
101122
with its own private entry in the temporaries it accesses (possibly
102123
restricted to *only_var_names*).
@@ -124,32 +145,32 @@ def privatize_temporaries_with_inames(
124145
end
125146
126147
facilitating loop interchange of the *imatrix* loop.
148+
127149
.. versionadded:: 2018.1
128150
"""
129151

130152
if isinstance(privatizing_inames, str):
131153
privatizing_inames = frozenset(
132-
s.strip()
133-
for s in privatizing_inames.split(","))
154+
s.strip() for s in privatizing_inames.split(",")
155+
)
134156

135157
if isinstance(only_var_names, str):
136158
only_var_names = frozenset(
137-
s.strip()
138-
for s in only_var_names.split(","))
159+
s.strip() for s in only_var_names.split(",")
160+
)
139161

140162
# {{{ sanity checks
141163

142164
if (only_var_names is not None
143165
and privatizing_inames <= kernel.all_inames()
144166
and not (frozenset(only_var_names) <= kernel.all_variable_names())):
145-
raise LoopyError(f"Some variables in '{only_var_names}'"
146-
f" not used in kernel '{kernel.name}'.")
167+
raise LoopyError(f"some variables in '{only_var_names}'"
168+
f" not used in kernel '{kernel.name}'")
147169

148170
# }}}
149171

150172
wmap = kernel.writer_map()
151-
152-
var_to_new_priv_axis_iname = {}
173+
var_to_new_priv_axis_iname: dict[str, frozenset[str]] = {}
153174

154175
# {{{ find variables that need extra indices
155176

@@ -162,27 +183,27 @@ def privatize_temporaries_with_inames(
162183

163184
priv_axis_inames = writer_insn.within_inames & privatizing_inames
164185

165-
referenced_priv_axis_inames = (priv_axis_inames
166-
& writer_insn.write_dependency_names())
186+
referenced_priv_axis_inames = (
187+
priv_axis_inames & writer_insn.write_dependency_names())
167188

168189
new_priv_axis_inames = priv_axis_inames - referenced_priv_axis_inames
169190

170191
if not new_priv_axis_inames:
171192
break
172193

173194
if tv.name in var_to_new_priv_axis_iname:
174-
if new_priv_axis_inames != set(var_to_new_priv_axis_iname[tv.name]):
175-
raise LoopyError("instruction '%s' requires adding "
176-
"indices for privatizing var '%s' on iname(s) '%s', "
177-
"but previous instructions required different "
178-
"inames '%s'"
179-
% (writer_insn_id, tv.name,
180-
", ".join(new_priv_axis_inames),
181-
", ".join(var_to_new_priv_axis_iname[tv.name])))
195+
if new_priv_axis_inames != var_to_new_priv_axis_iname[tv.name]:
196+
new_inames_str = ", ".join(new_priv_axis_inames)
197+
prev_inames_str = ", ".join(var_to_new_priv_axis_iname[tv.name])
198+
raise LoopyError(
199+
f"instruction '{writer_insn_id}' requires adding indices "
200+
"for privatizing var '{tv.name}' on iname(s) "
201+
f"'{new_inames_str}', but previous instructions required "
202+
f"different inames '{prev_inames_str}'")
182203

183204
continue
184205

185-
var_to_new_priv_axis_iname[tv.name] = set(new_priv_axis_inames)
206+
var_to_new_priv_axis_iname[tv.name] = frozenset(new_priv_axis_inames)
186207

187208
# }}}
188209

@@ -191,8 +212,8 @@ def privatize_temporaries_with_inames(
191212
from loopy.isl_helpers import static_max_of_pw_aff
192213
from loopy.symbolic import pw_aff_to_expr
193214

194-
priv_axis_iname_to_length = {}
195-
iname_to_lbound = {}
215+
priv_axis_iname_to_length: dict[str, ArithmeticExpression] = {}
216+
iname_to_lbound: dict[str, ArithmeticExpression] = {}
196217
for priv_axis_inames in var_to_new_priv_axis_iname.values():
197218
for iname in priv_axis_inames:
198219
if iname in priv_axis_iname_to_length:
@@ -209,7 +230,7 @@ def privatize_temporaries_with_inames(
209230

210231
from loopy.kernel.data import VectorizeTag
211232

212-
new_temp_vars = kernel.temporary_variables.copy()
233+
new_temp_vars = dict(kernel.temporary_variables)
213234
for tv_name, inames in var_to_new_priv_axis_iname.items():
214235
tv = new_temp_vars[tv_name]
215236
extra_shape = tuple(priv_axis_iname_to_length[iname] for iname in inames)
@@ -218,31 +239,32 @@ def privatize_temporaries_with_inames(
218239
if shape is None:
219240
shape = ()
220241

221-
dim_tags = ["c"] * (len(shape) + len(extra_shape))
242+
# NOTE: could be auto?
243+
assert isinstance(shape, tuple)
244+
ndim = len(shape)
245+
246+
dim_tags = ["c"] * (ndim + len(extra_shape))
222247
for i, iname in enumerate(inames):
223248
if kernel.iname_tags_of_type(iname, VectorizeTag):
224-
dim_tags[len(shape) + i] = "vec"
249+
dim_tags[ndim + i] = "vec"
225250

226251
base_indices = tv.base_indices
227252
if base_indices is not None:
228253
base_indices = base_indices + tuple([0]*len(extra_shape))
229254

230255
new_temp_vars[tv.name] = tv.copy(shape=shape + extra_shape,
231256
base_indices=base_indices,
232-
# Forget what you knew about data layout,
233-
# create from scratch.
257+
# Forget what you knew about data layout, create from scratch.
234258
dim_tags=dim_tags,
235259
dim_names=None)
236260

237261
# }}}
238262

239-
from pymbolic import var
240263
var_to_extra_iname = {
241-
var_name: tuple(var(iname) for iname in inames)
264+
var_name: tuple(p.Variable(iname) for iname in inames)
242265
for var_name, inames in var_to_new_priv_axis_iname.items()}
243266

244-
new_insns = []
245-
267+
new_insns: list[InstructionBase] = []
246268
for insn in kernel.instructions:
247269
eiii = ExtraInameIndexInserter(var_to_extra_iname,
248270
iname_to_lbound)
@@ -269,25 +291,34 @@ def privatize_temporaries_with_inames(
269291
# {{{ unprivatize temporaries with iname
270292

271293
class _InameRemover(IdentityMapper[[bool]]):
272-
def __init__(self, inames_to_remove, only_var_names):
294+
only_var_names: frozenset[str] | None
295+
inames_to_remove: frozenset[str]
296+
var_name_to_remove_indices: dict[str, dict[int, str]]
297+
298+
def __init__(self,
299+
inames_to_remove: frozenset[str],
300+
only_var_names: frozenset[str] | None) -> None:
273301
self.only_var_names = only_var_names
274302
self.inames_to_remove = inames_to_remove
275303
self.var_name_to_remove_indices = {}
276304
super().__init__()
277305

278-
def map_subscript(self, expr: p.Subscript, in_subscript: bool = False):
306+
@override
307+
def map_subscript(self, expr: p.Subscript, /,
308+
in_subscript: bool = False) -> Expression:
279309
assert isinstance(expr.aggregate, p.Variable)
280310
name = expr.aggregate.name
311+
281312
if not self.only_var_names or name in self.only_var_names:
282313
index = expr.index
283314
if not isinstance(index, tuple):
284315
index = (index,)
285316

286-
remove_indices = {}
287-
new_index = []
317+
remove_indices: dict[int, str] = {}
318+
new_index: list[Expression] = []
288319
for i, index_expr in enumerate(index):
289-
if isinstance(index_expr, pymbolic.primitives.Variable) and \
290-
index_expr.name in self.inames_to_remove:
320+
if (isinstance(index_expr, p.Variable)
321+
and index_expr.name in self.inames_to_remove):
291322
remove_indices[i] = index_expr.name
292323
else:
293324
new_index.append(index_expr)
@@ -303,8 +334,9 @@ def map_subscript(self, expr: p.Subscript, in_subscript: bool = False):
303334
self.var_name_to_remove_indices[name] = remove_indices
304335

305336
if new_index:
306-
new_index = new_index[0] if len(new_index) == 1 else tuple(new_index)
307-
return expr.aggregate[new_index]
337+
return expr.aggregate[
338+
new_index[0] if len(new_index) == 1 else tuple(new_index)
339+
]
308340
else:
309341
return expr.aggregate
310342
else:
@@ -313,7 +345,9 @@ def map_subscript(self, expr: p.Subscript, in_subscript: bool = False):
313345

314346
@for_each_kernel
315347
def unprivatize_temporaries_with_inames(
316-
kernel, privatizing_inames, only_var_names=None):
348+
kernel: LoopKernel,
349+
privatizing_inames: InameStr | InameStrSet,
350+
only_var_names: InameStr | InameStrSet | None = None) -> LoopKernel:
317351
"""This function reverses the effects of
318352
:func:`privatize_temporaries_with_inames` and removes the private entries
319353
in the temporaries each loop iteration of the *privatizing_inames*
@@ -342,13 +376,13 @@ def unprivatize_temporaries_with_inames(
342376

343377
if isinstance(privatizing_inames, str):
344378
privatizing_inames = frozenset(
345-
s.strip()
346-
for s in privatizing_inames.split(","))
379+
s.strip() for s in privatizing_inames.split(",")
380+
)
347381

348382
if isinstance(only_var_names, str):
349383
only_var_names = frozenset(
350-
s.strip()
351-
for s in only_var_names.split(","))
384+
s.strip() for s in only_var_names.split(",")
385+
)
352386

353387
# {{{ sanity checks
354388

@@ -372,18 +406,20 @@ def unprivatize_temporaries_with_inames(
372406

373407
from loopy.kernel.array import VectorArrayDimTag
374408

375-
new_temp_vars = kernel.temporary_variables.copy()
409+
new_temp_vars = dict(kernel.temporary_variables)
376410
for tv_name, tv in new_temp_vars.items():
377411
remove_indices = var_name_to_remove_indices.get(tv_name, {})
378412
new_shape = tv.shape
379413
if new_shape is not None:
380-
new_shape = tuple(dim for idim, dim in enumerate(new_shape)
414+
assert isinstance(new_shape, tuple)
415+
new_shape = tuple(
416+
dim for idim, dim in enumerate(new_shape)
381417
if idim not in remove_indices)
382418

383419
new_dim_tags = tv.dim_tags
384420
if new_dim_tags is not None:
385421
new_dim_tags = ["vec" if isinstance(dim_tag, VectorArrayDimTag) else "c"
386-
for idim, dim_tag in enumerate(new_dim_tags)]
422+
for _idim, dim_tag in enumerate(new_dim_tags)]
387423
new_dim_tags = tuple(dim for idim, dim in enumerate(new_dim_tags)
388424
if idim not in remove_indices)
389425

0 commit comments

Comments
 (0)