@@ -183,7 +183,10 @@ def get_default_src_tgt_arguments(self):
183183 if self .exclude_self else [])
184184 + gather_loopy_source_arguments (self .source_kernels ))
185185
186- def get_optimized_kernel (self , targets_is_obj_array , sources_is_obj_array ):
186+ def get_optimized_kernel (self , * ,
187+ targets_is_obj_array : bool = False ,
188+ sources_is_obj_array : bool = False ,
189+ ** kwargs : Any ) -> lp .TranslationUnit :
187190 # FIXME
188191 knl = self .get_kernel ()
189192
@@ -194,10 +197,8 @@ def get_optimized_kernel(self, targets_is_obj_array, sources_is_obj_array):
194197
195198 knl = lp .split_iname (knl , "itgt" , 1024 , outer_tag = "g.0" )
196199 knl = self ._allow_redundant_execution_of_knl_scaling (knl )
197- knl = lp .set_options (knl ,
198- enforce_variable_access_ordered = "no_check" )
200+ knl = lp .set_options (knl , enforce_variable_access_ordered = "no_check" )
199201
200- knl = register_optimization_preambles (knl , self .device )
201202 return knl
202203
203204
@@ -475,9 +476,11 @@ class P2PFromCSR(P2PBase):
475476 def default_name (self ):
476477 return "p2p_from_csr"
477478
478- def get_kernel (self ,
479- max_nsources_in_one_box : int , max_ntargets_in_one_box : int , * ,
480- work_items_per_group : int = 32 , is_gpu : bool = False ):
479+ def get_kernel (self , * ,
480+ max_nsources_in_one_box : int = 32 ,
481+ max_ntargets_in_one_box : int = 32 ,
482+ work_items_per_group : int = 32 ,
483+ is_gpu : bool = False , ** kwargs : Any ) -> lp .TranslationUnit :
481484 loopy_insns , _result_names = self .get_loopy_insns_and_result_names ()
482485 arguments = [
483486 * self .get_default_src_tgt_arguments (),
@@ -674,8 +677,10 @@ def get_kernel(self,
674677 "noutputs" : len (self .target_kernels )},
675678 )
676679
677- loopy_knl = lp .add_dtypes (
678- loopy_knl , {"nsources" : np .int32 , "ntargets" : np .int32 })
680+ loopy_knl = lp .add_dtypes (loopy_knl , {
681+ "nsources" : np .dtype (np .int32 ),
682+ "ntargets" : np .dtype (np .int32 ),
683+ })
679684
680685 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
681686 loopy_knl = lp .tag_inames (loopy_knl , "istrength*:unr" )
@@ -687,19 +692,24 @@ def get_kernel(self,
687692
688693 return loopy_knl
689694
690- def get_optimized_kernel (self ,
691- max_nsources_in_one_box : int ,
692- max_ntargets_in_one_box : int ,
693- strength_dtype : np .dtype [Any ],
694- source_dtype : np .dtype [Any ],
695- local_mem_size : int ,
696- is_gpu : bool ) :
695+ def get_optimized_kernel (self , * ,
696+ max_nsources_in_one_box : int = 32 ,
697+ max_ntargets_in_one_box : int = 32 ,
698+ strength_dtype : np .dtype [Any ] | None = None ,
699+ source_dtype : np .dtype [Any ] | None = None ,
700+ local_mem_size : int = 32 ,
701+ is_gpu : bool = False , ** kwargs ) -> lp . TranslationUnit :
697702 if not is_gpu :
698- knl = self .get_kernel (max_nsources_in_one_box ,
699- max_ntargets_in_one_box , is_gpu = is_gpu )
703+ knl = self .get_kernel (
704+ max_nsources_in_one_box = max_nsources_in_one_box ,
705+ max_ntargets_in_one_box = max_ntargets_in_one_box ,
706+ is_gpu = is_gpu )
700707 knl = lp .split_iname (knl , "itgt_box" , 4 , outer_tag = "g.0" )
701708 knl = self ._allow_redundant_execution_of_knl_scaling (knl )
702709 else :
710+ assert strength_dtype is not None
711+ assert source_dtype is not None
712+
703713 dtype_size = np .dtype (strength_dtype ).alignment
704714 work_items_per_group = min (256 , max_ntargets_in_one_box )
705715 total_local_mem = max_nsources_in_one_box * \
@@ -708,8 +718,9 @@ def get_optimized_kernel(self,
708718 # can be scheduled at the same time for latency hiding
709719 nprefetch = (2 * total_local_mem - 1 ) // local_mem_size + 1
710720
711- knl = self .get_kernel (max_nsources_in_one_box ,
712- max_ntargets_in_one_box ,
721+ knl = self .get_kernel (
722+ max_nsources_in_one_box = max_nsources_in_one_box ,
723+ max_ntargets_in_one_box = max_ntargets_in_one_box ,
713724 work_items_per_group = work_items_per_group ,
714725 is_gpu = is_gpu )
715726 knl = lp .tag_inames (knl , {"itgt_box" : "g.0" , "inner" : "l.0" })
@@ -771,12 +782,8 @@ def get_optimized_kernel(self,
771782 knl = lp .add_inames_to_insn (knl ,
772783 "inner" , "id:init_* or id:*_scaling or id:src_box_insn_*" )
773784 knl = lp .add_inames_to_insn (knl , "itgt_box" , "id:*_scaling" )
774- # knl = lp.set_options(knl, write_code=True)
775-
776- knl = lp .set_options (knl ,
777- enforce_variable_access_ordered = "no_check" )
778785
779- knl = register_optimization_preambles (knl , self . device )
786+ knl = lp . set_options (knl , enforce_variable_access_ordered = "no_check" )
780787 return knl
781788
782789 def __call__ (self , actx : PyOpenCLArrayContext , ** kwargs ):
@@ -786,8 +793,8 @@ def __call__(self, actx: PyOpenCLArrayContext, **kwargs):
786793
787794 is_gpu = not is_cl_cpu (actx )
788795 if is_gpu :
789- source_dtype = kwargs . get ( "sources" ) [0 ].dtype
790- strength_dtype = kwargs . get ( "strength" ) .dtype
796+ source_dtype = kwargs [ "sources" ] [0 ].dtype
797+ strength_dtype = kwargs [ "strength" ] .dtype
791798 else :
792799 # these are unused for not GPU and defeats the caching
793800 # set them to None to keep the caching across dtypes
0 commit comments