Skip to content

Commit 1aebaa7

Browse files
committed
More annotations in schedule.tools
1 parent cdce11a commit 1aebaa7

File tree

1 file changed

+73
-44
lines changed

1 file changed

+73
-44
lines changed

loopy/schedule/tools.py

Lines changed: 73 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -64,25 +64,30 @@
6464
import enum
6565
import itertools
6666
from dataclasses import dataclass
67-
from functools import cached_property, reduce
68-
from typing import TYPE_CHECKING, TypeAlias
67+
from functools import cached_property
68+
from typing import TYPE_CHECKING, Literal, TypeAlias, cast
6969

7070
from constantdict import constantdict
71+
from typing_extensions import override
7172

7273
import islpy as isl
7374
from pytools import fset_union, memoize_method, memoize_on_first_arg
7475

7576
from loopy.diagnostic import LoopyError
7677
from loopy.kernel.data import AddressSpace, ArrayArg, TemporaryVariable
7778
from loopy.schedule.tree import Tree
78-
from loopy.typing import InameStr, InameStrSet, not_none
79+
from loopy.typing import InameStr, InameStrSet, InsnId, not_none
7980

8081

8182
if TYPE_CHECKING:
8283
from collections.abc import Callable, Collection, Mapping, Sequence, Set
8384

85+
from pymbolic import Expression
86+
8487
from loopy.kernel import LoopKernel
88+
from loopy.kernel.instruction import InstructionBase
8589
from loopy.schedule import ScheduleItem
90+
from loopy.translation_unit import CallablesTable
8691

8792

8893
# {{{ block boundary finder
@@ -94,8 +99,8 @@ def get_block_boundaries(schedule: Sequence[ScheduleItem]) -> Mapping[int, int]:
9499
:class:`loopy.schedule.EndBlockItem`\ s and vice versa.
95100
"""
96101
from loopy.schedule import BeginBlockItem, EndBlockItem
97-
block_bounds = {}
98-
active_blocks = []
102+
block_bounds: dict[int, int] = {}
103+
active_blocks: list[int] = []
99104
for idx, sched_item in enumerate(schedule):
100105
if isinstance(sched_item, BeginBlockItem):
101106
active_blocks.append(idx)
@@ -114,8 +119,8 @@ def temporaries_read_in_subkernel(
114119
kernel: LoopKernel, subkernel_name: str) -> frozenset[str]:
115120
from loopy.kernel.tools import get_subkernel_to_insn_id_map
116121
insn_ids = get_subkernel_to_insn_id_map(kernel)[subkernel_name]
117-
inames = frozenset().union(*(kernel.insn_inames(insn_id)
118-
for insn_id in insn_ids))
122+
inames = fset_union(kernel.insn_inames(insn_id)
123+
for insn_id in insn_ids)
119124
domain_idxs = {kernel.get_home_domain_index(iname) for iname in inames}
120125
params = fset_union(
121126
kernel.domains[dom_idx].get_var_names_not_none(isl.dim_type.param)
@@ -142,12 +147,12 @@ def args_read_in_subkernel(
142147
kernel: LoopKernel, subkernel_name: str) -> frozenset[str]:
143148
from loopy.kernel.tools import get_subkernel_to_insn_id_map
144149
insn_ids = get_subkernel_to_insn_id_map(kernel)[subkernel_name]
145-
inames = frozenset().union(*(kernel.insn_inames(insn_id)
146-
for insn_id in insn_ids))
150+
inames = fset_union(kernel.insn_inames(insn_id)
151+
for insn_id in insn_ids)
147152
domain_idxs = {kernel.get_home_domain_index(iname) for iname in inames}
148-
params = frozenset().union(*(
153+
params = fset_union(
149154
kernel.domains[dom_idx].get_var_names_not_none(isl.dim_type.param)
150-
for dom_idx in domain_idxs))
155+
for dom_idx in domain_idxs)
151156
return (frozenset(arg
152157
for insn_id in insn_ids
153158
for arg in kernel.id_to_insn[insn_id].read_dependency_names()
@@ -209,6 +214,7 @@ class SubKernelArgInfo(KernelArgInfo):
209214
passed_temporaries: Sequence[str]
210215

211216
@property
217+
@override
212218
def passed_names(self) -> Sequence[str]:
213219
return (list(self.passed_arg_names)
214220
+ list(self.passed_inames)
@@ -221,10 +227,10 @@ def _should_temp_var_be_passed(tv: TemporaryVariable) -> bool:
221227

222228
class _SupportingNameTracker:
223229
def __init__(self, kernel: LoopKernel):
224-
self.kernel = kernel
225-
self.name_to_main_name: dict[str, str] = {}
230+
self.kernel: LoopKernel = kernel
231+
self.name_to_main_name: dict[str, Set[str]] = {}
226232

227-
def add_supporting_names_for(self, name):
233+
def add_supporting_names_for(self, name: str):
228234
var_descr = self.kernel.get_var_descriptor(name)
229235
for supp_name in var_descr.supporting_names():
230236
self.name_to_main_name[supp_name] = (
@@ -234,8 +240,8 @@ def add_supporting_names_for(self, name):
234240
def get_additional_args_and_tvs(
235241
self, already_passed: set[str]
236242
) -> tuple[list[str], list[str]]:
237-
additional_args = []
238-
additional_temporaries = []
243+
additional_args: list[str] = []
244+
additional_temporaries: list[str] = []
239245

240246
for supporting_name in sorted(frozenset(self.name_to_main_name)):
241247
if supporting_name not in already_passed:
@@ -257,7 +263,7 @@ def _process_args_for_arg_info(
257263

258264
args_expected: set[str] = set()
259265

260-
passed_arg_names = []
266+
passed_arg_names: list[str] = []
261267
for arg in kernel.args:
262268
if used_only and not (arg.name in args_read or arg.name in args_written):
263269
continue
@@ -401,8 +407,15 @@ def get_return_from_kernel_mapping(kernel: LoopKernel) -> Mapping[int, int | Non
401407

402408
# {{{ check for write races in accesses
403409

404-
def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table,
405-
address_space):
410+
def _check_for_access_races(
411+
map_a: isl.Map,
412+
insn_a: InstructionBase,
413+
map_b: isl.Map,
414+
insn_b: InstructionBase,
415+
knl: LoopKernel,
416+
callables_table: CallablesTable,
417+
address_space: AddressSpace
418+
):
406419
"""
407420
Returns *True* if the execution instances of *insn_a* and *insn_b*, accessing
408421
the same variable via access maps *map_a* and *map_b*, result in an access race.
@@ -439,7 +452,7 @@ def _check_for_access_races(map_a, insn_a, map_b, insn_b, knl, callables_table,
439452
# Step 1.4: Rename the dims with their iname tags i.e. (g.i or l.i)
440453
# Step 1.5: Name the ith output dims as _lp_dim{i}
441454

442-
updated_maps = []
455+
updated_maps: list[isl.Map] = []
443456

444457
for (map_, insn) in [
445458
(map_a, insn_a),
@@ -596,29 +609,30 @@ class AccessMapDescriptor(enum.Enum):
596609
class WriteRaceChecker:
597610
"""Used for checking for overlap between access ranges of instructions."""
598611

599-
def __init__(self, kernel, callables_table):
600-
self.kernel = kernel
601-
self.callables_table = callables_table
612+
def __init__(self, kernel: LoopKernel, callables_table: CallablesTable):
613+
self.kernel: LoopKernel = kernel
614+
self.callables_table: CallablesTable = callables_table
602615

603616
@cached_property
604617
def vars(self):
605618
return (self.kernel.get_written_variables()
606619
| self.kernel.get_read_variables())
607620

608621
@memoize_method
609-
def _get_access_maps(self, insn_id, access_dir):
622+
def _get_access_maps(self, insn_id: InsnId, access_dir: Literal["w", "any"]):
610623
from collections import defaultdict
611624

612625
from loopy.symbolic import BatchedAccessMapMapper
613626

614627
insn = self.kernel.id_to_insn[insn_id]
615628

616-
exprs = list(insn.assignees)
629+
exprs: list[Expression] = list(insn.assignees)
617630
if access_dir == "any":
618631
exprs.append(insn.expression)
619632
exprs.extend(insn.predicates)
620633

621-
access_maps = defaultdict(lambda: AccessMapDescriptor.DOES_NOT_ACCESS)
634+
access_maps: dict[str, AccessMapDescriptor | isl.Map] = defaultdict(
635+
lambda: AccessMapDescriptor.DOES_NOT_ACCESS)
622636

623637
arm = BatchedAccessMapMapper(self.kernel, self.vars, overestimate=True)
624638

@@ -629,11 +643,15 @@ def _get_access_maps(self, insn_id, access_dir):
629643
if arm.bad_subscripts[name]:
630644
access_maps[name] = AccessMapDescriptor.NON_AFFINE_ACCESS
631645
continue
632-
access_maps[name] = arm.access_maps[name][insn.within_inames]
646+
access_maps[name] = not_none(arm.access_maps[name][insn.within_inames])
633647

634648
return access_maps
635649

636-
def _get_access_map_for_var(self, insn_id, access_dir, var_name):
650+
def _get_access_map_for_var(self,
651+
insn_id: InsnId,
652+
access_dir: Literal["w", "any"],
653+
var_name: str
654+
):
637655
assert access_dir in ["w", "any"]
638656

639657
insn = self.kernel.id_to_insn[insn_id]
@@ -642,14 +660,25 @@ def _get_access_map_for_var(self, insn_id, access_dir, var_name):
642660
from loopy.kernel.instruction import MultiAssignmentBase
643661
if not isinstance(insn, MultiAssignmentBase):
644662
if access_dir == "any":
645-
return var_name in insn.dependency_names()
663+
if var_name in insn.dependency_names():
664+
return AccessMapDescriptor.NON_AFFINE_ACCESS
665+
else:
666+
return AccessMapDescriptor.DOES_NOT_ACCESS
646667
else:
647-
return var_name in insn.write_dependency_names()
668+
if var_name in insn.write_dependency_names():
669+
return AccessMapDescriptor.NON_AFFINE_ACCESS
670+
else:
671+
return AccessMapDescriptor.DOES_NOT_ACCESS
648672

649673
return self._get_access_maps(insn_id, access_dir)[var_name]
650674

651-
def do_accesses_result_in_races(self, insn1, insn1_dir, insn2, insn2_dir,
652-
var_name):
675+
def do_accesses_result_in_races(self,
676+
insn1: InsnId,
677+
insn1_dir: Literal["w", "any"],
678+
insn2: InsnId,
679+
insn2_dir: Literal["w", "any"],
680+
var_name: str
681+
):
653682
"""Determine whether the access maps to *var_name* in the two given
654683
instructions result in write races owing to concurrent iname tags. This
655684
determination is made 'conservatively', i.e. if precise information is
@@ -675,7 +704,7 @@ def do_accesses_result_in_races(self, insn1, insn1_dir, insn2, insn2_dir,
675704
return _check_for_access_races(insn1_amap, self.kernel.id_to_insn[insn1],
676705
insn2_amap, self.kernel.id_to_insn[insn2],
677706
self.kernel, self.callables_table,
678-
(self.kernel
707+
cast("AddressSpace", self.kernel
679708
.get_var_descriptor(var_name)
680709
.address_space))
681710

@@ -741,10 +770,7 @@ def separate_loop_nest(
741770
"""
742771
assert all(isinstance(loop_nest, frozenset) for loop_nest in loop_nests)
743772

744-
# annotation to avoid https://github.com/python/mypy/issues/17693
745-
emptyset: InameStrSet = frozenset()
746-
747-
assert inames_to_separate <= reduce(frozenset.union, loop_nests, emptyset)
773+
assert inames_to_separate <= fset_union(loop_nests)
748774

749775
# {{{ sanity check to ensure the loop nest *inames_to_separate* is possible
750776

@@ -760,8 +786,7 @@ def separate_loop_nest(
760786
# }}}
761787

762788
innermost_node = loop_nests[-1]
763-
# separate variable to avoid https://github.com/python/mypy/issues/17694
764-
outerer_loops = reduce(frozenset.union, loop_nests[:-1], emptyset)
789+
outerer_loops = fset_union(loop_nests[:-1])
765790
new_outer_node = inames_to_separate - outerer_loops
766791
new_inner_node = innermost_node - inames_to_separate
767792

@@ -783,7 +808,11 @@ def separate_loop_nest(
783808
return tree, new_outer_node, new_inner_node
784809

785810

786-
def _add_inner_loops(tree, outer_loop_nest, inner_loop_nest):
811+
def _add_inner_loops(
812+
tree: LoopNestTree,
813+
outer_loop_nest: InameStrSet,
814+
inner_loop_nest: InameStrSet
815+
) -> LoopNestTree:
787816
"""
788817
Returns a copy of *tree* that nests *inner_loop_nest* inside *outer_loop_nest*.
789818
"""
@@ -889,7 +918,7 @@ def _update_nesting_constraints(
889918
" schedule kernels with priority dependencies"
890919
" between sibling loop nests")
891920

892-
def _raise_loopy_err(x):
921+
def _raise_loopy_err(x: str):
893922
raise LoopyError(x)
894923

895924
# record strict priorities
@@ -911,9 +940,9 @@ def _raise_loopy_err(x):
911940

912941
assert loop_nest_tree.root == frozenset()
913942

914-
new_tree = Tree.from_root("")
943+
new_tree = Tree[InameStr].from_root("")
915944

916-
old_to_new_parent = {}
945+
old_to_new_parent: dict[InameStrSet, InameStr] = {}
917946

918947
old_to_new_parent[loop_nest_tree.root] = ""
919948

@@ -1054,7 +1083,7 @@ def _get_iname_to_tree_node_id_from_partial_loop_nest_tree(
10541083
10551084
:arg tree: A partial loop nest tree.
10561085
"""
1057-
iname_to_tree_node_id = {}
1086+
iname_to_tree_node_id: dict[InameStr, InameStrSet] = {}
10581087
for node in tree.nodes():
10591088
assert isinstance(node, frozenset)
10601089
for iname in node:

0 commit comments

Comments
 (0)