6464import enum
6565import itertools
6666from dataclasses import dataclass
67- from functools import cached_property , reduce
68- from typing import TYPE_CHECKING , TypeAlias
67+ from functools import cached_property
68+ from typing import TYPE_CHECKING , Literal , TypeAlias , cast
6969
7070from constantdict import constantdict
71+ from typing_extensions import override
7172
7273import islpy as isl
7374from pytools import fset_union , memoize_method , memoize_on_first_arg
7475
7576from loopy .diagnostic import LoopyError
7677from loopy .kernel .data import AddressSpace , ArrayArg , TemporaryVariable
7778from loopy .schedule .tree import Tree
78- from loopy .typing import InameStr , InameStrSet , not_none
79+ from loopy .typing import InameStr , InameStrSet , InsnId , not_none
7980
8081
8182if TYPE_CHECKING :
8283 from collections .abc import Callable , Collection , Mapping , Sequence , Set
8384
85+ from pymbolic import Expression
86+
8487 from loopy .kernel import LoopKernel
88+ from loopy .kernel .instruction import InstructionBase
8589 from loopy .schedule import ScheduleItem
90+ from loopy .translation_unit import CallablesTable
8691
8792
8893# {{{ block boundary finder
@@ -94,8 +99,8 @@ def get_block_boundaries(schedule: Sequence[ScheduleItem]) -> Mapping[int, int]:
9499 :class:`loopy.schedule.EndBlockItem`\ s and vice versa.
95100 """
96101 from loopy .schedule import BeginBlockItem , EndBlockItem
97- block_bounds = {}
98- active_blocks = []
102+ block_bounds : dict [ int , int ] = {}
103+ active_blocks : list [ int ] = []
99104 for idx , sched_item in enumerate (schedule ):
100105 if isinstance (sched_item , BeginBlockItem ):
101106 active_blocks .append (idx )
@@ -114,8 +119,8 @@ def temporaries_read_in_subkernel(
114119 kernel : LoopKernel , subkernel_name : str ) -> frozenset [str ]:
115120 from loopy .kernel .tools import get_subkernel_to_insn_id_map
116121 insn_ids = get_subkernel_to_insn_id_map (kernel )[subkernel_name ]
117- inames = frozenset (). union ( * (kernel .insn_inames (insn_id )
118- for insn_id in insn_ids ))
122+ inames = fset_union (kernel .insn_inames (insn_id )
123+ for insn_id in insn_ids )
119124 domain_idxs = {kernel .get_home_domain_index (iname ) for iname in inames }
120125 params = fset_union (
121126 kernel .domains [dom_idx ].get_var_names_not_none (isl .dim_type .param )
@@ -142,12 +147,12 @@ def args_read_in_subkernel(
142147 kernel : LoopKernel , subkernel_name : str ) -> frozenset [str ]:
143148 from loopy .kernel .tools import get_subkernel_to_insn_id_map
144149 insn_ids = get_subkernel_to_insn_id_map (kernel )[subkernel_name ]
145- inames = frozenset (). union ( * (kernel .insn_inames (insn_id )
146- for insn_id in insn_ids ))
150+ inames = fset_union (kernel .insn_inames (insn_id )
151+ for insn_id in insn_ids )
147152 domain_idxs = {kernel .get_home_domain_index (iname ) for iname in inames }
148- params = frozenset (). union ( * (
153+ params = fset_union (
149154 kernel .domains [dom_idx ].get_var_names_not_none (isl .dim_type .param )
150- for dom_idx in domain_idxs ))
155+ for dom_idx in domain_idxs )
151156 return (frozenset (arg
152157 for insn_id in insn_ids
153158 for arg in kernel .id_to_insn [insn_id ].read_dependency_names ()
@@ -209,6 +214,7 @@ class SubKernelArgInfo(KernelArgInfo):
209214 passed_temporaries : Sequence [str ]
210215
211216 @property
217+ @override
212218 def passed_names (self ) -> Sequence [str ]:
213219 return (list (self .passed_arg_names )
214220 + list (self .passed_inames )
@@ -221,10 +227,10 @@ def _should_temp_var_be_passed(tv: TemporaryVariable) -> bool:
221227
222228class _SupportingNameTracker :
223229 def __init__ (self , kernel : LoopKernel ):
224- self .kernel = kernel
225- self .name_to_main_name : dict [str , str ] = {}
230+ self .kernel : LoopKernel = kernel
231+ self .name_to_main_name : dict [str , Set [ str ] ] = {}
226232
227- def add_supporting_names_for (self , name ):
233+ def add_supporting_names_for (self , name : str ):
228234 var_descr = self .kernel .get_var_descriptor (name )
229235 for supp_name in var_descr .supporting_names ():
230236 self .name_to_main_name [supp_name ] = (
@@ -234,8 +240,8 @@ def add_supporting_names_for(self, name):
234240 def get_additional_args_and_tvs (
235241 self , already_passed : set [str ]
236242 ) -> tuple [list [str ], list [str ]]:
237- additional_args = []
238- additional_temporaries = []
243+ additional_args : list [ str ] = []
244+ additional_temporaries : list [ str ] = []
239245
240246 for supporting_name in sorted (frozenset (self .name_to_main_name )):
241247 if supporting_name not in already_passed :
@@ -257,7 +263,7 @@ def _process_args_for_arg_info(
257263
258264 args_expected : set [str ] = set ()
259265
260- passed_arg_names = []
266+ passed_arg_names : list [ str ] = []
261267 for arg in kernel .args :
262268 if used_only and not (arg .name in args_read or arg .name in args_written ):
263269 continue
@@ -401,8 +407,15 @@ def get_return_from_kernel_mapping(kernel: LoopKernel) -> Mapping[int, int | Non
401407
402408# {{{ check for write races in accesses
403409
404- def _check_for_access_races (map_a , insn_a , map_b , insn_b , knl , callables_table ,
405- address_space ):
410+ def _check_for_access_races (
411+ map_a : isl .Map ,
412+ insn_a : InstructionBase ,
413+ map_b : isl .Map ,
414+ insn_b : InstructionBase ,
415+ knl : LoopKernel ,
416+ callables_table : CallablesTable ,
417+ address_space : AddressSpace
418+ ):
406419 """
407420 Returns *True* if the execution instances of *insn_a* and *insn_b*, accessing
408421 the same variable via access maps *map_a* and *map_b*, result in an access race.
@@ -439,7 +452,7 @@ def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table,
439452 # Step 1.4: Rename the dims with their iname tags i.e. (g.i or l.i)
440453 # Step 1.5: Name the ith output dims as _lp_dim{i}
441454
442- updated_maps = []
455+ updated_maps : list [ isl . Map ] = []
443456
444457 for (map_ , insn ) in [
445458 (map_a , insn_a ),
@@ -596,29 +609,30 @@ class AccessMapDescriptor(enum.Enum):
596609class WriteRaceChecker :
597610 """Used for checking for overlap between access ranges of instructions."""
598611
599- def __init__ (self , kernel , callables_table ):
600- self .kernel = kernel
601- self .callables_table = callables_table
612+ def __init__ (self , kernel : LoopKernel , callables_table : CallablesTable ):
613+ self .kernel : LoopKernel = kernel
614+ self .callables_table : CallablesTable = callables_table
602615
603616 @cached_property
604617 def vars (self ):
605618 return (self .kernel .get_written_variables ()
606619 | self .kernel .get_read_variables ())
607620
608621 @memoize_method
609- def _get_access_maps (self , insn_id , access_dir ):
622+ def _get_access_maps (self , insn_id : InsnId , access_dir : Literal [ "w" , "any" ] ):
610623 from collections import defaultdict
611624
612625 from loopy .symbolic import BatchedAccessMapMapper
613626
614627 insn = self .kernel .id_to_insn [insn_id ]
615628
616- exprs = list (insn .assignees )
629+ exprs : list [ Expression ] = list (insn .assignees )
617630 if access_dir == "any" :
618631 exprs .append (insn .expression )
619632 exprs .extend (insn .predicates )
620633
621- access_maps = defaultdict (lambda : AccessMapDescriptor .DOES_NOT_ACCESS )
634+ access_maps : dict [str , AccessMapDescriptor | isl .Map ] = defaultdict (
635+ lambda : AccessMapDescriptor .DOES_NOT_ACCESS )
622636
623637 arm = BatchedAccessMapMapper (self .kernel , self .vars , overestimate = True )
624638
@@ -629,11 +643,15 @@ def _get_access_maps(self, insn_id, access_dir):
629643 if arm .bad_subscripts [name ]:
630644 access_maps [name ] = AccessMapDescriptor .NON_AFFINE_ACCESS
631645 continue
632- access_maps [name ] = arm .access_maps [name ][insn .within_inames ]
646+ access_maps [name ] = not_none ( arm .access_maps [name ][insn .within_inames ])
633647
634648 return access_maps
635649
636- def _get_access_map_for_var (self , insn_id , access_dir , var_name ):
650+ def _get_access_map_for_var (self ,
651+ insn_id : InsnId ,
652+ access_dir : Literal ["w" , "any" ],
653+ var_name : str
654+ ):
637655 assert access_dir in ["w" , "any" ]
638656
639657 insn = self .kernel .id_to_insn [insn_id ]
@@ -642,14 +660,25 @@ def _get_access_map_for_var(self, insn_id, access_dir, var_name):
642660 from loopy .kernel .instruction import MultiAssignmentBase
643661 if not isinstance (insn , MultiAssignmentBase ):
644662 if access_dir == "any" :
645- return var_name in insn .dependency_names ()
663+ if var_name in insn .dependency_names ():
664+ return AccessMapDescriptor .NON_AFFINE_ACCESS
665+ else :
666+ return AccessMapDescriptor .DOES_NOT_ACCESS
646667 else :
647- return var_name in insn .write_dependency_names ()
668+ if var_name in insn .write_dependency_names ():
669+ return AccessMapDescriptor .NON_AFFINE_ACCESS
670+ else :
671+ return AccessMapDescriptor .DOES_NOT_ACCESS
648672
649673 return self ._get_access_maps (insn_id , access_dir )[var_name ]
650674
651- def do_accesses_result_in_races (self , insn1 , insn1_dir , insn2 , insn2_dir ,
652- var_name ):
675+ def do_accesses_result_in_races (self ,
676+ insn1 : InsnId ,
677+ insn1_dir : Literal ["w" , "any" ],
678+ insn2 : InsnId ,
679+ insn2_dir : Literal ["w" , "any" ],
680+ var_name : str
681+ ):
653682 """Determine whether the access maps to *var_name* in the two given
654683 instructions result in write races owing to concurrent iname tags. This
655684 determination is made 'conservatively', i.e. if precise information is
@@ -675,7 +704,7 @@ def do_accesses_result_in_races(self, insn1, insn1_dir, insn2, insn2_dir,
675704 return _check_for_access_races (insn1_amap , self .kernel .id_to_insn [insn1 ],
676705 insn2_amap , self .kernel .id_to_insn [insn2 ],
677706 self .kernel , self .callables_table ,
678- ( self .kernel
707+ cast ( "AddressSpace" , self .kernel
679708 .get_var_descriptor (var_name )
680709 .address_space ))
681710
@@ -741,10 +770,7 @@ def separate_loop_nest(
741770 """
742771 assert all (isinstance (loop_nest , frozenset ) for loop_nest in loop_nests )
743772
744- # annotation to avoid https://github.com/python/mypy/issues/17693
745- emptyset : InameStrSet = frozenset ()
746-
747- assert inames_to_separate <= reduce (frozenset .union , loop_nests , emptyset )
773+ assert inames_to_separate <= fset_union (loop_nests )
748774
749775 # {{{ sanity check to ensure the loop nest *inames_to_separate* is possible
750776
@@ -760,8 +786,7 @@ def separate_loop_nest(
760786 # }}}
761787
762788 innermost_node = loop_nests [- 1 ]
763- # separate variable to avoid https://github.com/python/mypy/issues/17694
764- outerer_loops = reduce (frozenset .union , loop_nests [:- 1 ], emptyset )
789+ outerer_loops = fset_union (loop_nests [:- 1 ])
765790 new_outer_node = inames_to_separate - outerer_loops
766791 new_inner_node = innermost_node - inames_to_separate
767792
@@ -783,7 +808,11 @@ def separate_loop_nest(
783808 return tree , new_outer_node , new_inner_node
784809
785810
786- def _add_inner_loops (tree , outer_loop_nest , inner_loop_nest ):
811+ def _add_inner_loops (
812+ tree : LoopNestTree ,
813+ outer_loop_nest : InameStrSet ,
814+ inner_loop_nest : InameStrSet
815+ ) -> LoopNestTree :
787816 """
788817 Returns a copy of *tree* that nests *inner_loop_nest* inside *outer_loop_nest*.
789818 """
@@ -889,7 +918,7 @@ def _update_nesting_constraints(
889918 " schedule kernels with priority dependencies"
890919 " between sibling loop nests" )
891920
892- def _raise_loopy_err (x ):
921+ def _raise_loopy_err (x : str ):
893922 raise LoopyError (x )
894923
895924 # record strict priorities
@@ -911,9 +940,9 @@ def _raise_loopy_err(x):
911940
912941 assert loop_nest_tree .root == frozenset ()
913942
914- new_tree = Tree .from_root ("" )
943+ new_tree = Tree [ InameStr ] .from_root ("" )
915944
916- old_to_new_parent = {}
945+ old_to_new_parent : dict [ InameStrSet , InameStr ] = {}
917946
918947 old_to_new_parent [loop_nest_tree .root ] = ""
919948
@@ -1054,7 +1083,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
10541083
10551084 :arg tree: A partial loop nest tree.
10561085 """
1057- iname_to_tree_node_id = {}
1086+ iname_to_tree_node_id : dict [ InameStr , InameStrSet ] = {}
10581087 for node in tree .nodes ():
10591088 assert isinstance (node , frozenset )
10601089 for iname in node :
0 commit comments