Skip to content

Commit 2274478

Browse files
committed
check for non-affineness instead of subscripts
1 parent 56a22c6 commit 2274478

File tree

2 files changed

+47
-40
lines changed

2 files changed

+47
-40
lines changed

pytato/scalar_expr.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
THE SOFTWARE.
4343
"""
4444

45+
4546
import re
4647
from collections.abc import Iterable, Mapping, Set as AbstractSet
4748
from typing import (
@@ -55,6 +56,8 @@
5556
from typing_extensions import Never, TypeIs, override
5657

5758
import pymbolic.primitives as prim
59+
from loopy.diagnostic import ExpressionToAffineConversionError
60+
from loopy.symbolic import guarded_pwaff_from_expr
5861
from pymbolic import ArithmeticExpression, Bool, Expression, expr_dataclass
5962
from pymbolic.mapper import (
6063
CombineMapper as CombineMapperBase,
@@ -368,4 +371,18 @@ def get_reduction_induction_variables(expr: Expression) -> AbstractSet[str]:
368371
"""
369372
return InductionVariableCollector()(expr)
370373

374+
375+
def is_quasi_affine(expr: Expression) -> bool:
376+
import islpy as isl
377+
space = isl.Space.create_from_names(
378+
isl.DEFAULT_CONTEXT,
379+
set=list(get_dependencies(expr)),
380+
)
381+
try:
382+
guarded_pwaff_from_expr(space, expr)
383+
except ExpressionToAffineConversionError:
384+
return False
385+
return True
386+
387+
371388
# vim: foldmethod=marker

pytato/target/loopy/codegen.py

Lines changed: 30 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@
6868
INT_CLASSES,
6969
ScalarExpression,
7070
TypeCast,
71+
is_quasi_affine,
7172
)
7273
from pytato.tags import (
7374
ForceValueArgTag,
@@ -505,28 +506,36 @@ def map_index_lambda(self, expr: IndexLambda,
505506

506507
loopy_shape = shape_to_scalar_expression(expr.shape, self, state)
507508

508-
# If the scalar expression contains any reductions with bounds expressions
509-
# that index into a binding, need to store the results of those expressions
510-
# as scalar temporaries
511-
subscript_detector = SubscriptDetector()
512-
513509
redn_bounds = {
514510
var_name: redn.bounds[var_name]
515511
for var_name, redn in var_to_reduction.items()}
516512

517-
# FIXME: Forcing storage of expressions containing processed reductions for
518-
# now; attempting to generalize to unmaterialized expressions would require
519-
# handling at least two complications:
520-
# 1) final inames aren't assigned until the expression is stored, so any
521-
# temporary variables defined below would need to be finalized at the
522-
# point of storage, not here
513+
loopy_redn_bounds: Mapping[str, tuple[Expression, Expression]] = {
514+
var_name: cast(
515+
"tuple[Expression, Expression]",
516+
tuple(
517+
self.exprgen_mapper(bound, prstnt_ctx, local_ctx)
518+
for bound in bounds))
519+
for var_name, bounds in redn_bounds.items()}
520+
521+
# If the scalar expression contains any reductions with bounds
522+
# expressions that are non-affine, we need to store the results of those
523+
# expressions as scalar temporaries.
524+
# FIXME: For now, forcing storage of expressions containing such
525+
# processed reductions; attempting to generalize to unmaterialized
526+
# expressions would require handling at least two complications:
527+
# 1) final inames aren't assigned until the expression is stored, so
528+
# any temporary variables defined below would need to be finalized
529+
# at the point of storage, not here
523530
# 2) lp.make_reduction_inames_unique does not rename the temporaries
524-
# created below, so something would need to be done to make them unique
525-
# across all index lambda evaluations.
526-
store_result = expr.tags_of_type(ImplStored) or any(
527-
subscript_detector(bound)
528-
for bounds in redn_bounds.values()
529-
for bound in bounds)
531+
# created below, so something would need to be done to make them
532+
# unique across all index lambda evaluations.
533+
store_result = (
534+
expr.tags_of_type(ImplStored)
535+
or any(
536+
not is_quasi_affine(bound)
537+
for bounds in loopy_redn_bounds.values()
538+
for bound in bounds))
530539

531540
name: str | None = None
532541
inames: tuple[str, ...] | None = None
@@ -550,13 +559,13 @@ def map_index_lambda(self, expr: IndexLambda,
550559
str, tuple[ArithmeticExpression, ArithmeticExpression]] = {}
551560
bound_prefixes = ("l", "u")
552561
for var_name, bounds in redn_bounds.items():
562+
loopy_bounds = loopy_redn_bounds[var_name]
553563
new_bounds_list: list[ArithmeticExpression] = []
554-
for bound_prefix, bound in zip(bound_prefixes, bounds, strict=True):
555-
if subscript_detector(bound):
564+
for bound_prefix, bound, loopy_bound in zip(
565+
bound_prefixes, bounds, loopy_bounds, strict=True):
566+
if not is_quasi_affine(loopy_bound):
556567
unique_name = var_to_reduction_unique_name[var_name]
557568
bound_name = f"{unique_name}_{bound_prefix}bound"
558-
loopy_bound = self.exprgen_mapper(
559-
bound, prstnt_ctx, local_ctx)
560569
bound_result: ImplementedResult = InlinedResult(
561570
loopy_bound, expr.ndim, prstnt_ctx.depends_on)
562571
bound_result = StoredResult(
@@ -793,25 +802,6 @@ def map_call(self, expr: Call, state: CodeGenState) -> None:
793802
}
794803

795804

796-
class SubscriptDetector(scalar_expr.CombineMapper[bool, []]):
797-
"""Returns *True* if a scalar expression contains any subscripts."""
798-
@override
799-
def combine(self, values: Iterable[bool]) -> bool:
800-
return any(values)
801-
802-
@override
803-
def map_algebraic_leaf(self, expr: prim.AlgebraicLeaf) -> bool:
804-
return False
805-
806-
@override
807-
def map_subscript(self, expr: prim.Subscript) -> bool:
808-
return True
809-
810-
@override
811-
def map_constant(self, expr: object) -> bool:
812-
return False
813-
814-
815805
class ReductionCollector(scalar_expr.CombineMapper[frozenset[scalar_expr.Reduce], []]):
816806
"""
817807
Constructs a :class:`frozenset` containing all instances of

0 commit comments

Comments
 (0)