2525
2626import logging
2727from collections import defaultdict
28+ from dataclasses import dataclass
2829from functools import reduce
29- from typing import TYPE_CHECKING
30+ from typing import TYPE_CHECKING , cast
3031
3132import numpy as np
33+ from typing_extensions import override
3234
3335import islpy as isl
3436from islpy import dim_type
7678if TYPE_CHECKING :
7779 from collections .abc import Mapping , Sequence
7880
81+ import pymbolic .primitives as p
7982 from pymbolic .typing import Expression
8083
8184 from loopy .kernel import LoopKernel
@@ -707,10 +710,12 @@ def subst_func(x):
707710# }}}
708711
709712
710- class _AccessCheckMapper (WalkMapper ):
711- def __init__ (self , kernel , callables_table ):
712- self .kernel = kernel
713- self .callables_table = callables_table
713+ @dataclass
714+ class _AccessCheckMapper (WalkMapper [[isl .Set , str ]]):
715+ kernel : LoopKernel
716+ callables_table : CallablesTable
717+
718+ def __post_init__ (self ) -> None :
714719 super ().__init__ ()
715720
716721 @memoize_method
@@ -719,14 +724,20 @@ def _make_slab(self, space, iname, start, stop):
719724 return make_slab (space , iname , start , stop )
720725
721726 @memoize_method
722- def _get_access_range (self , domain , subscript ):
723- from loopy .symbolic import UnableToDetermineAccessRangeError , get_access_map
727+ def _get_access_range (
728+ self ,
729+ domain : isl .Set ,
730+ subscript : tuple [Expression , ...]
731+ ):
732+ from loopy .diagnostic import UnableToDetermineAccessRangeError
733+ from loopy .symbolic import get_access_map
724734 try :
725735 return get_access_map (domain , subscript ).range ()
726736 except UnableToDetermineAccessRangeError :
727737 return None
728738
729- def map_subscript (self , expr , domain , insn_id ):
739+ @override
740+ def map_subscript (self , expr : p .Subscript , domain : isl .Set , insn_id : str ):
730741 WalkMapper .map_subscript (self , expr , domain , insn_id )
731742
732743 from pymbolic .primitives import Variable
@@ -742,6 +753,7 @@ def map_subscript(self, expr, domain, insn_id):
742753 shape = tv .shape
743754
744755 if shape is not None :
756+ assert isinstance (shape , tuple )
745757 subscript = expr .index
746758
747759 if not isinstance (subscript , tuple ):
@@ -787,7 +799,8 @@ def map_subscript(self, expr, domain, insn_id):
787799 " establish '%s' is a subset of '%s')."
788800 % (expr , insn_id , access_range , shape_domain ))
789801
790- def map_if (self , expr , domain , insn_id ):
802+ @override
803+ def map_if (self , expr : p . If , domain : isl .Set , insn_id : str ):
791804 from loopy .symbolic import condition_to_set
792805 then_set = condition_to_set (domain .space , expr .condition )
793806 if then_set is None :
@@ -800,7 +813,8 @@ def map_if(self, expr, domain, insn_id):
800813 self .rec (expr .then , domain & then_set , insn_id )
801814 self .rec (expr .else_ , domain & else_set , insn_id )
802815
803- def map_call (self , expr , domain , insn_id ):
816+ @override
817+ def map_call (self , expr : p .Call , domain : isl .Set , insn_id : str ):
804818 # perform access checks on the call arguments
805819 super ().map_call (expr , domain , insn_id )
806820
@@ -817,7 +831,9 @@ def map_call(self, expr, domain, insn_id):
817831 and isinstance (self .callables_table [expr .function .name ],
818832 CallableKernel )):
819833
820- subkernel = self .callables_table [expr .function .name ].subkernel
834+ subkernel = cast (
835+ "CallableKernel" ,
836+ self .callables_table [expr .function .name ]).subkernel
821837
822838 # The plan here is to add the constraints coming from the values
823839 # args passed at a call-site as assumptions to the callee. To avoid
@@ -835,8 +851,8 @@ def map_call(self, expr, domain, insn_id):
835851
836852 kw_space = isl .Space .create_from_names (
837853 subkernel .isl_context , set = [],
838- params = ( get_dependencies (tuple (kwargs .values ()))
839- | set ( kwargs .keys ())) )
854+ params = [ * get_dependencies (tuple (kwargs .values ())),
855+ * kwargs .keys ()] )
840856
841857 extra_assumptions = isl .BasicSet .universe (kw_space ).params ()
842858
@@ -894,8 +910,8 @@ def _check_bounds_inner(kernel: LoopKernel, callables_table: CallablesTable) ->
894910 domain , assumptions = isl .align_two (domain , kernel .assumptions )
895911 domain_with_assumptions = domain & assumptions
896912
897- def run_acm (expr ):
898- acm (expr , domain_with_assumptions , insn .id ) # noqa: B023
913+ def run_acm (expr : Expression ):
914+ acm (expr , domain_with_assumptions , not_none ( insn .id ) ) # noqa: B023
899915 return expr
900916
901917 insn .with_transformed_expressions (run_acm )
@@ -1659,7 +1675,10 @@ def _get_sub_array_ref_swept_range(
16591675 from loopy .symbolic import get_access_map
16601676 domain = kernel .get_inames_domain (frozenset ({iname_var .name
16611677 for iname_var in sar .swept_inames }))
1662- return get_access_map (domain , sar .swept_inames , kernel .assumptions ).range ()
1678+ return get_access_map (
1679+ domain .to_set (),
1680+ sar .swept_inames ,
1681+ kernel .assumptions .to_set ()).range ()
16631682
16641683
16651684def _are_sub_array_refs_equivalent (
@@ -1875,7 +1894,7 @@ def pre_codegen_checks(t_unit: TranslationUnit) -> None:
18751894
18761895def check_implemented_domains (
18771896 kernel : LoopKernel ,
1878- implemented_domains : Mapping [str , isl .Set ],
1897+ implemented_domains : Mapping [str , Sequence [ isl .Set ] ],
18791898 code : str | None = None ,
18801899 ) -> bool :
18811900 from islpy import align_two , dim_type
@@ -1942,7 +1961,7 @@ def check_implemented_domains(
19421961 d_minus_i = desired_domain - insn_impl_domain
19431962
19441963 parameter_inames = {
1945- insn_domain .get_dim_name (dim_type .param , i )
1964+ not_none ( insn_domain .get_dim_name (dim_type .param , i ) )
19461965 for i in range (insn_impl_domain .dim (dim_type .param ))}
19471966
19481967 lines = []
0 commit comments