@@ -135,15 +135,15 @@ def get_cache_key(self):
135135 return (type (self ).__name__ , self .src_expansion , self .tgt_expansion )
136136
137137 @abstractmethod
138- def get_kernel (self ):
138+ @override
139+ def get_kernel (self ) -> lp .TranslationUnit :
139140 pass
140141
141142 def get_optimized_kernel (self ):
142143 # FIXME
143144 knl = self .get_kernel ()
144- knl = lp .split_iname (knl , "itgt_box" , 64 , outer_tag = "g.0" , inner_tag = "l.0" )
145+ return lp .split_iname (knl , "itgt_box" , 64 , outer_tag = "g.0" , inner_tag = "l.0" )
145146
146- return knl
147147
148148# }}}
149149
@@ -259,18 +259,14 @@ def get_kernel(self):
259259 loopy_knl = knl .prepare_loopy_kernel (loopy_knl )
260260
261261 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
262- loopy_knl = lp .set_options (loopy_knl ,
262+ return lp .set_options (loopy_knl ,
263263 enforce_variable_access_ordered = "no_check" )
264264
265- return loopy_knl
266-
267265 @override
268266 def get_optimized_kernel (self ):
269267 # FIXME
270268 knl = self .get_kernel ()
271- knl = lp .split_iname (knl , "itgt_box" , 64 , outer_tag = "g.0" , inner_tag = "l.0" )
272-
273- return knl
269+ return lp .split_iname (knl , "itgt_box" , 64 , outer_tag = "g.0" , inner_tag = "l.0" )
274270
275271 def __call__ (self , actx : ArrayContext , ** kwargs ):
276272 """
@@ -511,11 +507,9 @@ def get_kernel(self, result_dtype):
511507
512508 def get_optimized_kernel (self , result_dtype ):
513509 knl = self .get_kernel (result_dtype )
514- knl = self .tgt_expansion .m2l_translation .optimize_loopy_kernel (
510+ return self .tgt_expansion .m2l_translation .optimize_loopy_kernel (
515511 knl , self .tgt_expansion , self .src_expansion )
516512
517- return knl
518-
519513 def __call__ (self , actx : ArrayContext , ** kwargs ):
520514 """
521515 :arg src_expansions:
@@ -612,22 +606,18 @@ def get_kernel(self, result_dtype):
612606
613607 loopy_knl = lp .merge ([loopy_knl , translation_classes_data_knl ])
614608 loopy_knl = lp .inline_callable_kernel (loopy_knl , "m2l_data" )
615- loopy_knl = lp .set_options (loopy_knl ,
609+ return lp .set_options (loopy_knl ,
616610 enforce_variable_access_ordered = "no_check" ,
617611 # FIXME: Without this, Loopy spends an eternity checking
618612 # scattered writes to global variables to see whether barriers
619613 # need to be inserted.
620614 disable_global_barriers = True )
621615
622- return loopy_knl
623-
624616 def get_optimized_kernel (self , result_dtype ):
625617 # FIXME
626618 knl = self .get_kernel (result_dtype )
627619 knl = lp .tag_inames (knl , "idim*:unr" )
628- knl = lp .tag_inames (knl , {"itr_class" : "g.0" })
629-
630- return knl
620+ return lp .tag_inames (knl , {"itr_class" : "g.0" })
631621
632622 def __call__ (self , actx : ArrayContext , ** kwargs ):
633623 """
@@ -722,9 +712,7 @@ def get_kernel(self, result_dtype):
722712 loopy_knl = expn .prepare_loopy_kernel (loopy_knl )
723713
724714 loopy_knl = lp .merge ([loopy_knl , single_box_preprocess_knl ])
725- loopy_knl = lp .inline_callable_kernel (loopy_knl , "m2l_preprocess_inner" )
726-
727- return loopy_knl
715+ return lp .inline_callable_kernel (loopy_knl , "m2l_preprocess_inner" )
728716
729717 def get_optimized_kernel (self , result_dtype ):
730718 knl = self .get_kernel (result_dtype )
@@ -822,18 +810,16 @@ def get_kernel(self, result_dtype):
822810 loopy_knl = lp .merge ([loopy_knl , single_box_postprocess_knl ])
823811 loopy_knl = lp .inline_callable_kernel (loopy_knl , "m2l_postprocess_inner" )
824812
825- loopy_knl = lp .set_options (loopy_knl ,
813+ return lp .set_options (loopy_knl ,
826814 enforce_variable_access_ordered = "no_check" )
827- return loopy_knl
828815
829816 def get_optimized_kernel (self , result_dtype ):
830817 knl = self .get_kernel (result_dtype )
831818 knl = lp .tag_inames (knl , "itgt_box:g.0" )
832819 _ , optimizations = self .get_inner_knl_and_optimizations (result_dtype )
833820 for optimization in optimizations :
834821 knl = optimization (knl )
835- knl = lp .add_inames_for_unused_hw_axes (knl )
836- return knl
822+ return lp .add_inames_for_unused_hw_axes (knl )
837823
838824 def __call__ (self , actx : ArrayContext , ** kwargs ):
839825 """
@@ -943,11 +929,9 @@ def get_kernel(self):
943929 loopy_knl = knl .prepare_loopy_kernel (loopy_knl )
944930
945931 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
946- loopy_knl = lp .set_options (loopy_knl ,
932+ return lp .set_options (loopy_knl ,
947933 enforce_variable_access_ordered = "no_check" )
948934
949- return loopy_knl
950-
951935 def __call__ (self , actx : ArrayContext , ** kwargs ):
952936 """
953937 :arg src_expansions:
@@ -1050,11 +1034,9 @@ def get_kernel(self):
10501034 loopy_knl = knl .prepare_loopy_kernel (loopy_knl )
10511035
10521036 loopy_knl = lp .tag_inames (loopy_knl , "idim*:unr" )
1053- loopy_knl = lp .set_options (loopy_knl ,
1037+ return lp .set_options (loopy_knl ,
10541038 enforce_variable_access_ordered = "no_check" )
10551039
1056- return loopy_knl
1057-
10581040 def __call__ (self , actx : ArrayContext , ** kwargs ):
10591041 """
10601042 :arg src_expansions:
0 commit comments