6868 INT_CLASSES ,
6969 ScalarExpression ,
7070 TypeCast ,
71+ is_quasi_affine ,
7172)
7273from pytato .tags import (
7374 ForceValueArgTag ,
@@ -505,28 +506,36 @@ def map_index_lambda(self, expr: IndexLambda,
505506
506507 loopy_shape = shape_to_scalar_expression (expr .shape , self , state )
507508
508- # If the scalar expression contains any reductions with bounds expressions
509- # that index into a binding, need to store the results of those expressions
510- # as scalar temporaries
511- subscript_detector = SubscriptDetector ()
512-
513509 redn_bounds = {
514510 var_name : redn .bounds [var_name ]
515511 for var_name , redn in var_to_reduction .items ()}
516512
517- # FIXME: Forcing storage of expressions containing processed reductions for
518- # now; attempting to generalize to unmaterialized expressions would require
519- # handling at least two complications:
520- # 1) final inames aren't assigned until the expression is stored, so any
521- # temporary variables defined below would need to be finalized at the
522- # point of storage, not here
513+ loopy_redn_bounds : Mapping [str , tuple [Expression , Expression ]] = {
514+ var_name : cast (
515+ "tuple[Expression, Expression]" ,
516+ tuple (
517+ self .exprgen_mapper (bound , prstnt_ctx , local_ctx )
518+ for bound in bounds ))
519+ for var_name , bounds in redn_bounds .items ()}
520+
521+ # If the scalar expression contains any reductions with bounds
522+ # expressions that are non-affine, we need to store the results of those
523+ # expressions as scalar temporaries.
524+ # FIXME: For now, forcing storage of expressions containing such
525+ # processed reductions; attempting to generalize to unmaterialized
526+ # expressions would require handling at least two complications:
527+ # 1) final inames aren't assigned until the expression is stored, so
528+ # any temporary variables defined below would need to be finalized
529+ # at the point of storage, not here
523530 # 2) lp.make_reduction_inames_unique does not rename the temporaries
524- # created below, so something would need to be done to make them unique
525- # across all index lambda evaluations.
526- store_result = expr .tags_of_type (ImplStored ) or any (
527- subscript_detector (bound )
528- for bounds in redn_bounds .values ()
529- for bound in bounds )
531+ # created below, so something would need to be done to make them
532+ # unique across all index lambda evaluations.
533+ store_result = (
534+ expr .tags_of_type (ImplStored )
535+ or any (
536+ not is_quasi_affine (bound )
537+ for bounds in loopy_redn_bounds .values ()
538+ for bound in bounds ))
530539
531540 name : str | None = None
532541 inames : tuple [str , ...] | None = None
@@ -550,13 +559,13 @@ def map_index_lambda(self, expr: IndexLambda,
550559 str , tuple [ArithmeticExpression , ArithmeticExpression ]] = {}
551560 bound_prefixes = ("l" , "u" )
552561 for var_name , bounds in redn_bounds .items ():
562+ loopy_bounds = loopy_redn_bounds [var_name ]
553563 new_bounds_list : list [ArithmeticExpression ] = []
554- for bound_prefix , bound in zip (bound_prefixes , bounds , strict = True ):
555- if subscript_detector (bound ):
564+ for bound_prefix , bound , loopy_bound in zip (
565+ bound_prefixes , bounds , loopy_bounds , strict = True ):
566+ if not is_quasi_affine (loopy_bound ):
556567 unique_name = var_to_reduction_unique_name [var_name ]
557568 bound_name = f"{ unique_name } _{ bound_prefix } bound"
558- loopy_bound = self .exprgen_mapper (
559- bound , prstnt_ctx , local_ctx )
560569 bound_result : ImplementedResult = InlinedResult (
561570 loopy_bound , expr .ndim , prstnt_ctx .depends_on )
562571 bound_result = StoredResult (
@@ -793,25 +802,6 @@ def map_call(self, expr: Call, state: CodeGenState) -> None:
793802}
794803
795804
796- class SubscriptDetector (scalar_expr .CombineMapper [bool , []]):
797- """Returns *True* if a scalar expression contains any subscripts."""
798- @override
799- def combine (self , values : Iterable [bool ]) -> bool :
800- return any (values )
801-
802- @override
803- def map_algebraic_leaf (self , expr : prim .AlgebraicLeaf ) -> bool :
804- return False
805-
806- @override
807- def map_subscript (self , expr : prim .Subscript ) -> bool :
808- return True
809-
810- @override
811- def map_constant (self , expr : object ) -> bool :
812- return False
813-
814-
815805class ReductionCollector (scalar_expr .CombineMapper [frozenset [scalar_expr .Reduce ], []]):
816806 """
817807 Constructs a :class:`frozenset` containing all instances of
0 commit comments