2323THE SOFTWARE.
2424"""
2525
26-
2726import logging
27+ from typing import TYPE_CHECKING
28+
29+ from typing_extensions import override
2830
29- import pymbolic
3031import pymbolic .primitives as p
3132
3233from loopy .diagnostic import LoopyError
3334from loopy .translation_unit import for_each_kernel
3435
3536
37+ if TYPE_CHECKING :
38+ from collections .abc import Mapping , Sequence
39+
40+ from pymbolic .typing import ArithmeticExpression , Expression
41+
42+ from loopy .kernel import LoopKernel
43+ from loopy .kernel .instruction import InstructionBase
44+ from loopy .typing import InameStr , InameStrSet
45+
3646logger = logging .getLogger (__name__ )
3747
3848
5161
5262
5363class ExtraInameIndexInserter (IdentityMapper [[]]):
54- def __init__ (self , var_to_new_inames , iname_to_lbound ):
64+ var_to_new_inames : Mapping [str , Sequence [p .Variable ]]
65+ iname_to_lbound : Mapping [str , ArithmeticExpression ]
66+ seen_priv_axis_inames : set [str ]
67+
68+ def __init__ (self ,
69+ var_to_new_inames : Mapping [str , Sequence [p .Variable ]],
70+ iname_to_lbound : Mapping [str , ArithmeticExpression ]) -> None :
5571 self .var_to_new_inames = var_to_new_inames
5672 self .iname_to_lbound = iname_to_lbound
5773 self .seen_priv_axis_inames = set ()
5874 super ().__init__ ()
5975
60- def map_subscript (self , expr : p .Subscript ):
76+ @override
77+ def map_subscript (self , expr : p .Subscript , / ) -> Expression :
6178 assert isinstance (expr .aggregate , p .Variable )
6279 try :
6380 extra_idx = self .var_to_new_inames [expr .aggregate .name ]
@@ -71,32 +88,36 @@ def map_subscript(self, expr: p.Subscript):
7188
7289 self .seen_priv_axis_inames .update (v .name for v in extra_idx )
7390
74- new_idx = index + tuple (flatten (v - self .iname_to_lbound [v .name ])
75- for v in extra_idx )
91+ new_idx = index + tuple (
92+ flatten (v - self .iname_to_lbound [v .name ]) for v in extra_idx
93+ )
7694
7795 if len (new_idx ) == 1 :
7896 new_idx = new_idx [0 ]
7997 return expr .aggregate [new_idx ]
8098
81- def map_variable (self , expr : p .Variable ):
99+ @override
100+ def map_variable (self , expr : p .Variable , / ) -> Expression :
82101 try :
83102 new_idx = self .var_to_new_inames [expr .name ]
84103 except KeyError :
85104 return expr
86105 else :
87106 self .seen_priv_axis_inames .update (v .name for v in new_idx )
88107
89- new_idx = tuple (flatten (v - self .iname_to_lbound [v .name ])
90- for v in new_idx )
91-
108+ new_idx = tuple (flatten (v - self .iname_to_lbound [v .name ]) for v in new_idx )
92109 if len (new_idx ) == 1 :
93110 new_idx = new_idx [0 ]
111+
94112 return expr [new_idx ]
95113
96114
97115@for_each_kernel
98116def privatize_temporaries_with_inames (
99- kernel , privatizing_inames , only_var_names = None ):
117+ kernel : LoopKernel ,
118+ privatizing_inames : InameStr | InameStrSet ,
119+ only_var_names : InameStr | InameStrSet | None = None ,
120+ ) -> LoopKernel :
100121 """This function provides each loop iteration of the *privatizing_inames*
101122 with its own private entry in the temporaries it accesses (possibly
102123 restricted to *only_var_names*).
@@ -124,32 +145,32 @@ def privatize_temporaries_with_inames(
124145 end
125146
126147 facilitating loop interchange of the *imatrix* loop.
148+
127149 .. versionadded:: 2018.1
128150 """
129151
130152 if isinstance (privatizing_inames , str ):
131153 privatizing_inames = frozenset (
132- s . strip ( )
133- for s in privatizing_inames . split ( "," ) )
154+ s . strip () for s in privatizing_inames . split ( "," )
155+ )
134156
135157 if isinstance (only_var_names , str ):
136158 only_var_names = frozenset (
137- s . strip ( )
138- for s in only_var_names . split ( "," ) )
159+ s . strip () for s in only_var_names . split ( "," )
160+ )
139161
140162 # {{{ sanity checks
141163
142164 if (only_var_names is not None
143165 and privatizing_inames <= kernel .all_inames ()
144166 and not (frozenset (only_var_names ) <= kernel .all_variable_names ())):
145- raise LoopyError (f"Some variables in '{ only_var_names } '"
146- f" not used in kernel '{ kernel .name } '. " )
167+ raise LoopyError (f"some variables in '{ only_var_names } '"
168+ f" not used in kernel '{ kernel .name } '" )
147169
148170 # }}}
149171
150172 wmap = kernel .writer_map ()
151-
152- var_to_new_priv_axis_iname = {}
173+ var_to_new_priv_axis_iname : dict [str , frozenset [str ]] = {}
153174
154175 # {{{ find variables that need extra indices
155176
@@ -162,27 +183,27 @@ def privatize_temporaries_with_inames(
162183
163184 priv_axis_inames = writer_insn .within_inames & privatizing_inames
164185
165- referenced_priv_axis_inames = (priv_axis_inames
166- & writer_insn .write_dependency_names ())
186+ referenced_priv_axis_inames = (
187+ priv_axis_inames & writer_insn .write_dependency_names ())
167188
168189 new_priv_axis_inames = priv_axis_inames - referenced_priv_axis_inames
169190
170191 if not new_priv_axis_inames :
171192 break
172193
173194 if tv .name in var_to_new_priv_axis_iname :
174- if new_priv_axis_inames != set ( var_to_new_priv_axis_iname [tv .name ]) :
175- raise LoopyError ( "instruction '%s' requires adding "
176- "indices for privatizing var '%s' on iname(s) '%s' , "
177- "but previous instructions required different "
178- "inames '%s' "
179- % ( writer_insn_id , tv . name ,
180- ", " . join ( new_priv_axis_inames ),
181- ", " . join ( var_to_new_priv_axis_iname [ tv . name ])) )
195+ if new_priv_axis_inames != var_to_new_priv_axis_iname [tv .name ]:
196+ new_inames_str = ", " . join ( new_priv_axis_inames )
197+ prev_inames_str = " , ". join ( var_to_new_priv_axis_iname [ tv . name ])
198+ raise LoopyError (
199+ f"instruction ' { writer_insn_id } ' requires adding indices "
200+ f"for privatizing var '{tv.name}' on iname(s) "
201+ f"' { new_inames_str } ', but previous instructions required "
202+ f"different inames ' { prev_inames_str } '" )
182203
183204 continue
184205
185- var_to_new_priv_axis_iname [tv .name ] = set (new_priv_axis_inames )
206+ var_to_new_priv_axis_iname [tv .name ] = frozenset (new_priv_axis_inames )
186207
187208 # }}}
188209
@@ -191,8 +212,8 @@ def privatize_temporaries_with_inames(
191212 from loopy .isl_helpers import static_max_of_pw_aff
192213 from loopy .symbolic import pw_aff_to_expr
193214
194- priv_axis_iname_to_length = {}
195- iname_to_lbound = {}
215+ priv_axis_iname_to_length : dict [ str , ArithmeticExpression ] = {}
216+ iname_to_lbound : dict [ str , ArithmeticExpression ] = {}
196217 for priv_axis_inames in var_to_new_priv_axis_iname .values ():
197218 for iname in priv_axis_inames :
198219 if iname in priv_axis_iname_to_length :
@@ -209,7 +230,7 @@ def privatize_temporaries_with_inames(
209230
210231 from loopy .kernel .data import VectorizeTag
211232
212- new_temp_vars = kernel .temporary_variables . copy ( )
233+ new_temp_vars = dict ( kernel .temporary_variables )
213234 for tv_name , inames in var_to_new_priv_axis_iname .items ():
214235 tv = new_temp_vars [tv_name ]
215236 extra_shape = tuple (priv_axis_iname_to_length [iname ] for iname in inames )
@@ -218,31 +239,32 @@ def privatize_temporaries_with_inames(
218239 if shape is None :
219240 shape = ()
220241
221- dim_tags = ["c" ] * (len (shape ) + len (extra_shape ))
242+ # NOTE: could be auto?
243+ assert isinstance (shape , tuple )
244+ ndim = len (shape )
245+
246+ dim_tags = ["c" ] * (ndim + len (extra_shape ))
222247 for i , iname in enumerate (inames ):
223248 if kernel .iname_tags_of_type (iname , VectorizeTag ):
224- dim_tags [len ( shape ) + i ] = "vec"
249+ dim_tags [ndim + i ] = "vec"
225250
226251 base_indices = tv .base_indices
227252 if base_indices is not None :
228253 base_indices = base_indices + tuple ([0 ]* len (extra_shape ))
229254
230255 new_temp_vars [tv .name ] = tv .copy (shape = shape + extra_shape ,
231256 base_indices = base_indices ,
232- # Forget what you knew about data layout,
233- # create from scratch.
257+ # Forget what you knew about data layout, create from scratch.
234258 dim_tags = dim_tags ,
235259 dim_names = None )
236260
237261 # }}}
238262
239- from pymbolic import var
240263 var_to_extra_iname = {
241- var_name : tuple (var (iname ) for iname in inames )
264+ var_name : tuple (p . Variable (iname ) for iname in inames )
242265 for var_name , inames in var_to_new_priv_axis_iname .items ()}
243266
244- new_insns = []
245-
267+ new_insns : list [InstructionBase ] = []
246268 for insn in kernel .instructions :
247269 eiii = ExtraInameIndexInserter (var_to_extra_iname ,
248270 iname_to_lbound )
@@ -269,25 +291,34 @@ def privatize_temporaries_with_inames(
269291# {{{ unprivatize temporaries with iname
270292
271293class _InameRemover (IdentityMapper [[bool ]]):
272- def __init__ (self , inames_to_remove , only_var_names ):
294+ only_var_names : frozenset [str ] | None
295+ inames_to_remove : frozenset [str ]
296+ var_name_to_remove_indices : dict [str , dict [int , str ]]
297+
298+ def __init__ (self ,
299+ inames_to_remove : frozenset [str ],
300+ only_var_names : frozenset [str ] | None ) -> None :
273301 self .only_var_names = only_var_names
274302 self .inames_to_remove = inames_to_remove
275303 self .var_name_to_remove_indices = {}
276304 super ().__init__ ()
277305
278- def map_subscript (self , expr : p .Subscript , in_subscript : bool = False ):
306+ @override
307+ def map_subscript (self , expr : p .Subscript , / ,
308+ in_subscript : bool = False ) -> Expression :
279309 assert isinstance (expr .aggregate , p .Variable )
280310 name = expr .aggregate .name
311+
281312 if not self .only_var_names or name in self .only_var_names :
282313 index = expr .index
283314 if not isinstance (index , tuple ):
284315 index = (index ,)
285316
286- remove_indices = {}
287- new_index = []
317+ remove_indices : dict [ int , str ] = {}
318+ new_index : list [ Expression ] = []
288319 for i , index_expr in enumerate (index ):
289- if isinstance (index_expr , pymbolic . primitives . Variable ) and \
290- index_expr .name in self .inames_to_remove :
320+ if ( isinstance (index_expr , p . Variable )
321+ and index_expr .name in self .inames_to_remove ) :
291322 remove_indices [i ] = index_expr .name
292323 else :
293324 new_index .append (index_expr )
@@ -303,8 +334,9 @@ def map_subscript(self, expr: p.Subscript, in_subscript: bool = False):
303334 self .var_name_to_remove_indices [name ] = remove_indices
304335
305336 if new_index :
306- new_index = new_index [0 ] if len (new_index ) == 1 else tuple (new_index )
307- return expr .aggregate [new_index ]
337+ return expr .aggregate [
338+ new_index [0 ] if len (new_index ) == 1 else tuple (new_index )
339+ ]
308340 else :
309341 return expr .aggregate
310342 else :
@@ -313,7 +345,9 @@ def map_subscript(self, expr: p.Subscript, in_subscript: bool = False):
313345
314346@for_each_kernel
315347def unprivatize_temporaries_with_inames (
316- kernel , privatizing_inames , only_var_names = None ):
348+ kernel : LoopKernel ,
349+ privatizing_inames : InameStr | InameStrSet ,
350+ only_var_names : InameStr | InameStrSet | None = None ) -> LoopKernel :
317351 """This function reverses the effects of
318352 :func:`privatize_temporaries_with_inames` and removes the private entries
319353 in the temporaries each loop iteration of the *privatizing_inames*
@@ -342,13 +376,13 @@ def unprivatize_temporaries_with_inames(
342376
343377 if isinstance (privatizing_inames , str ):
344378 privatizing_inames = frozenset (
345- s . strip ( )
346- for s in privatizing_inames . split ( "," ) )
379+ s . strip () for s in privatizing_inames . split ( "," )
380+ )
347381
348382 if isinstance (only_var_names , str ):
349383 only_var_names = frozenset (
350- s . strip ( )
351- for s in only_var_names . split ( "," ) )
384+ s . strip () for s in only_var_names . split ( "," )
385+ )
352386
353387 # {{{ sanity checks
354388
@@ -372,18 +406,20 @@ def unprivatize_temporaries_with_inames(
372406
373407 from loopy .kernel .array import VectorArrayDimTag
374408
375- new_temp_vars = kernel .temporary_variables . copy ( )
409+ new_temp_vars = dict ( kernel .temporary_variables )
376410 for tv_name , tv in new_temp_vars .items ():
377411 remove_indices = var_name_to_remove_indices .get (tv_name , {})
378412 new_shape = tv .shape
379413 if new_shape is not None :
380- new_shape = tuple (dim for idim , dim in enumerate (new_shape )
414+ assert isinstance (new_shape , tuple )
415+ new_shape = tuple (
416+ dim for idim , dim in enumerate (new_shape )
381417 if idim not in remove_indices )
382418
383419 new_dim_tags = tv .dim_tags
384420 if new_dim_tags is not None :
385421 new_dim_tags = ["vec" if isinstance (dim_tag , VectorArrayDimTag ) else "c"
386- for idim , dim_tag in enumerate (new_dim_tags )]
422+ for _idim , dim_tag in enumerate (new_dim_tags )]
387423 new_dim_tags = tuple (dim for idim , dim in enumerate (new_dim_tags )
388424 if idim not in remove_indices )
389425
0 commit comments