3737from typing_extensions import Never , override
3838
3939import loopy as lp
40+ import loopy .diagnostic as lp_diagnostic
4041import loopy .symbolic as lp_symbolic
4142import pymbolic .primitives as prim
4243from pymbolic import ArithmeticExpression , var
@@ -505,28 +506,64 @@ def map_index_lambda(self, expr: IndexLambda,
505506
506507 loopy_shape = shape_to_scalar_expression (expr .shape , self , state )
507508
508- # If the scalar expression contains any reductions with bounds expressions
509- # that index into a binding, need to store the results of those expressions
510- # as scalar temporaries
511- subscript_detector = SubscriptDetector ()
512-
513509 redn_bounds = {
514510 var_name : redn .bounds [var_name ]
515511 for var_name , redn in var_to_reduction .items ()}
516512
517- # FIXME: Forcing storage of expressions containing processed reductions for
518- # now; attempting to generalize to unmaterialized expressions would require
519- # handling at least two complications:
520- # 1) final inames aren't assigned until the expression is stored, so any
521- # temporary variables defined below would need to be finalized at the
522- # point of storage, not here
513+ loopy_redn_bounds : Mapping [str , tuple [Expression , Expression ]] = {}
514+ for var_name , bounds in redn_bounds .items ():
515+ loopy_bound_list : list [Expression ] = []
516+ for bound in bounds :
517+ loopy_bound_list .append (
518+ self .exprgen_mapper (bound , prstnt_ctx , local_ctx ))
519+ loopy_redn_bounds [var_name ] = cast (
520+ "tuple[Expression, Expression]" , tuple (loopy_bound_list ))
521+
522+ store_result = expr .tags_of_type (ImplStored )
523+
524+ # FIXME: Did I do this right (the space creation below)? I end up with
525+ # a space that looks like Space("[_pt_data_1] -> { [_0, _r0] }"), where
526+ # _pt_data_1 is the array being indexed into in the reduction bound
527+
528+ set_names = set (
529+ tuple (f"_{ idim } " for idim in range (expr .ndim ))
530+ + tuple (loopy_redn_bounds .keys ()))
531+
532+ param_names : set [str ] = set ()
533+ for sdep in map (scalar_expr .get_dependencies , loopy_shape ):
534+ param_names |= sdep
535+
536+ for bounds in loopy_redn_bounds .values ():
537+ for sdep in map (scalar_expr .get_dependencies , bounds ):
538+ param_names |= sdep
539+
540+ param_names -= set_names
541+
542+ space = isl .Space .create_from_names (
543+ isl .DEFAULT_CONTEXT ,
544+ set = sorted (set_names ),
545+ params = sorted (param_names ))
546+
547+ # If the scalar expression contains any reductions with bounds
548+ # expressions that are non-affine, we need to store the results of those
549+ # expressions as scalar temporaries.
550+ # FIXME: For now, forcing storage of expressions containing such
551+ # processed reductions; attempting to generalize to unmaterialized
552+ # expressions would require handling at least two complications:
553+ # 1) final inames aren't assigned until the expression is stored, so
554+ # any temporary variables defined below would need to be finalized
555+ # at the point of storage, not here
523556 # 2) lp.make_reduction_inames_unique does not rename the temporaries
524- # created below, so something would need to be done to make them unique
525- # across all index lambda evaluations.
526- store_result = expr .tags_of_type (ImplStored ) or any (
527- subscript_detector (bound )
528- for bounds in redn_bounds .values ()
529- for bound in bounds )
557+ # created below, so something would need to be done to make them
558+ # unique across all index lambda evaluations.
559+
560+ for bounds in loopy_redn_bounds .values ():
561+ for bound in bounds :
562+ try :
563+ lp_symbolic .guarded_pwaff_from_expr (space , bound )
564+ except lp_diagnostic .ExpressionToAffineConversionError :
565+ store_result = True
566+ break
530567
531568 name : str | None = None
532569 inames : tuple [str , ...] | None = None
@@ -550,13 +587,15 @@ def map_index_lambda(self, expr: IndexLambda,
550587 str , tuple [ArithmeticExpression , ArithmeticExpression ]] = {}
551588 bound_prefixes = ("l" , "u" )
552589 for var_name , bounds in redn_bounds .items ():
590+ loopy_bounds = loopy_redn_bounds [var_name ]
553591 new_bounds_list : list [ArithmeticExpression ] = []
554- for bound_prefix , bound in zip (bound_prefixes , bounds , strict = True ):
555- if subscript_detector (bound ):
592+ for bound_prefix , bound , loopy_bound in zip (
593+ bound_prefixes , bounds , loopy_bounds , strict = True ):
594+ try :
595+ lp_symbolic .guarded_pwaff_from_expr (space , loopy_bound )
596+ except lp_diagnostic .ExpressionToAffineConversionError :
556597 unique_name = var_to_reduction_unique_name [var_name ]
557598 bound_name = f"{ unique_name } _{ bound_prefix } bound"
558- loopy_bound = self .exprgen_mapper (
559- bound , prstnt_ctx , local_ctx )
560599 bound_result : ImplementedResult = InlinedResult (
561600 loopy_bound , expr .ndim , prstnt_ctx .depends_on )
562601 bound_result = StoredResult (
@@ -793,25 +832,6 @@ def map_call(self, expr: Call, state: CodeGenState) -> None:
793832}
794833
795834
796- class SubscriptDetector (scalar_expr .CombineMapper [bool , []]):
797- """Returns *True* if a scalar expression contains any subscripts."""
798- @override
799- def combine (self , values : Iterable [bool ]) -> bool :
800- return any (values )
801-
802- @override
803- def map_algebraic_leaf (self , expr : prim .AlgebraicLeaf ) -> bool :
804- return False
805-
806- @override
807- def map_subscript (self , expr : prim .Subscript ) -> bool :
808- return True
809-
810- @override
811- def map_constant (self , expr : object ) -> bool :
812- return False
813-
814-
815835class ReductionCollector (scalar_expr .CombineMapper [frozenset [scalar_expr .Reduce ], []]):
816836 """
817837 Constructs a :class:`frozenset` containing all instances of
0 commit comments