Skip to content

Commit faa9bfc

Browse files
committed
check for non-affineness instead of subscripts
1 parent 42869bb commit faa9bfc

File tree

1 file changed

+60
-40
lines changed

1 file changed

+60
-40
lines changed

pytato/target/loopy/codegen.py

Lines changed: 60 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from typing_extensions import Never, override
3838

3939
import loopy as lp
40+
import loopy.diagnostic as lp_diagnostic
4041
import loopy.symbolic as lp_symbolic
4142
import pymbolic.primitives as prim
4243
from pymbolic import ArithmeticExpression, var
@@ -505,28 +506,64 @@ def map_index_lambda(self, expr: IndexLambda,
505506

506507
loopy_shape = shape_to_scalar_expression(expr.shape, self, state)
507508

508-
# If the scalar expression contains any reductions with bounds expressions
509-
# that index into a binding, need to store the results of those expressions
510-
# as scalar temporaries
511-
subscript_detector = SubscriptDetector()
512-
513509
redn_bounds = {
514510
var_name: redn.bounds[var_name]
515511
for var_name, redn in var_to_reduction.items()}
516512

517-
# FIXME: Forcing storage of expressions containing processed reductions for
518-
# now; attempting to generalize to unmaterialized expressions would require
519-
# handling at least two complications:
520-
# 1) final inames aren't assigned until the expression is stored, so any
521-
# temporary variables defined below would need to be finalized at the
522-
# point of storage, not here
513+
loopy_redn_bounds: Mapping[str, tuple[Expression, Expression]] = {}
514+
for var_name, bounds in redn_bounds.items():
515+
loopy_bound_list: list[Expression] = []
516+
for bound in bounds:
517+
loopy_bound_list.append(
518+
self.exprgen_mapper(bound, prstnt_ctx, local_ctx))
519+
loopy_redn_bounds[var_name] = cast(
520+
"tuple[Expression, Expression]", tuple(loopy_bound_list))
521+
522+
store_result = expr.tags_of_type(ImplStored)
523+
524+
# FIXME: Did I do this right (the space creation below)? I end up with
525+
# a space that looks like Space("[_pt_data_1] -> { [_0, _r0] }"), where
526+
# _pt_data_1 is the array being indexed into in the reduction bound
527+
528+
set_names = set(
529+
tuple(f"_{idim}" for idim in range(expr.ndim))
530+
+ tuple(loopy_redn_bounds.keys()))
531+
532+
param_names: set[str] = set()
533+
for sdep in map(scalar_expr.get_dependencies, loopy_shape):
534+
param_names |= sdep
535+
536+
for bounds in loopy_redn_bounds.values():
537+
for sdep in map(scalar_expr.get_dependencies, bounds):
538+
param_names |= sdep
539+
540+
param_names -= set_names
541+
542+
space = isl.Space.create_from_names(
543+
isl.DEFAULT_CONTEXT,
544+
set=sorted(set_names),
545+
params=sorted(param_names))
546+
547+
# If the scalar expression contains any reductions with bounds
548+
# expressions that are non-affine, we need to store the results of those
549+
# expressions as scalar temporaries.
550+
# FIXME: For now, forcing storage of expressions containing such
551+
# processed reductions; attempting to generalize to unmaterialized
552+
# expressions would require handling at least two complications:
553+
# 1) final inames aren't assigned until the expression is stored, so
554+
# any temporary variables defined below would need to be finalized
555+
# at the point of storage, not here
523556
# 2) lp.make_reduction_inames_unique does not rename the temporaries
524-
# created below, so something would need to be done to make them unique
525-
# across all index lambda evaluations.
526-
store_result = expr.tags_of_type(ImplStored) or any(
527-
subscript_detector(bound)
528-
for bounds in redn_bounds.values()
529-
for bound in bounds)
557+
# created below, so something would need to be done to make them
558+
# unique across all index lambda evaluations.
559+
560+
for bounds in loopy_redn_bounds.values():
561+
for bound in bounds:
562+
try:
563+
lp_symbolic.guarded_pwaff_from_expr(space, bound)
564+
except lp_diagnostic.ExpressionToAffineConversionError:
565+
store_result = True
566+
break
530567

531568
name: str | None = None
532569
inames: tuple[str, ...] | None = None
@@ -550,13 +587,15 @@ def map_index_lambda(self, expr: IndexLambda,
550587
str, tuple[ArithmeticExpression, ArithmeticExpression]] = {}
551588
bound_prefixes = ("l", "u")
552589
for var_name, bounds in redn_bounds.items():
590+
loopy_bounds = loopy_redn_bounds[var_name]
553591
new_bounds_list: list[ArithmeticExpression] = []
554-
for bound_prefix, bound in zip(bound_prefixes, bounds, strict=True):
555-
if subscript_detector(bound):
592+
for bound_prefix, bound, loopy_bound in zip(
593+
bound_prefixes, bounds, loopy_bounds, strict=True):
594+
try:
595+
lp_symbolic.guarded_pwaff_from_expr(space, loopy_bound)
596+
except lp_diagnostic.ExpressionToAffineConversionError:
556597
unique_name = var_to_reduction_unique_name[var_name]
557598
bound_name = f"{unique_name}_{bound_prefix}bound"
558-
loopy_bound = self.exprgen_mapper(
559-
bound, prstnt_ctx, local_ctx)
560599
bound_result: ImplementedResult = InlinedResult(
561600
loopy_bound, expr.ndim, prstnt_ctx.depends_on)
562601
bound_result = StoredResult(
@@ -793,25 +832,6 @@ def map_call(self, expr: Call, state: CodeGenState) -> None:
793832
}
794833

795834

796-
class SubscriptDetector(scalar_expr.CombineMapper[bool, []]):
797-
"""Returns *True* if a scalar expression contains any subscripts."""
798-
@override
799-
def combine(self, values: Iterable[bool]) -> bool:
800-
return any(values)
801-
802-
@override
803-
def map_algebraic_leaf(self, expr: prim.AlgebraicLeaf) -> bool:
804-
return False
805-
806-
@override
807-
def map_subscript(self, expr: prim.Subscript) -> bool:
808-
return True
809-
810-
@override
811-
def map_constant(self, expr: object) -> bool:
812-
return False
813-
814-
815835
class ReductionCollector(scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]):
816836
"""
817837
Constructs a :class:`frozenset` containing all instances of

0 commit comments

Comments
 (0)