Skip to content

Commit 4e17bea

Browse files
committed
[shape_poly] Fix the handling of __pow__ for symbolic dimensions
The code for handling exponentiation was wrong, and there were no tests.
1 parent 7214a3a commit 4e17bea

File tree

2 files changed

+43
-10
lines changed

2 files changed

+43
-10
lines changed

jax/_src/export/shape_poly.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -764,13 +764,23 @@ def __rmul__(self, other):
764764
return _DimExpr._linear_combination(self, other, 0, 0, self.scope)
765765
return _ensure_poly(other, "mul", self.scope).__mul__(self)
766766

767-
def __pow__(self, power, modulo=None):
768-
assert modulo is None
769-
try:
770-
power = int(power)
771-
except:
772-
raise InconclusiveDimensionOperation(f"Symbolic dimension cannot be raised to non-integer power '{self}' ^ '{power}'")
773-
return functools.reduce(op.mul, [self] * power)
767+
def __pow__(self, power: core.DimSize, modulo=None):
768+
if modulo is not None:
769+
raise NotImplementedError("__pow__ modulo not implemented")
770+
if is_symbolic_dim(power):
771+
return power.__rpow__(self) # type: ignore
772+
if power != int(power):
773+
raise ValueError(f"Symbolic dimension cannot be raised to non-integer powers: '{self}' ** '{power}'")
774+
if power >= 0:
775+
return functools.reduce(op.mul, [self] * power, 1)
776+
# We don't support negative powers, because JAX does not allow negative
777+
# powers for integers
778+
raise ValueError(f"Symbolic dimension cannot be raised to negative powers: '{self}' ** '{power}'")
779+
780+
def __rpow__(self, other, modulo=None):
781+
if modulo is not None:
782+
raise NotImplementedError("__rpow__ modulo not implemented")
783+
return self.__jax_array__().__rpow__(other)
774784

775785
def __floordiv__(self, divisor):
776786
if isinstance(divisor, core.Tracer) or not _convertible_to_poly(divisor):

tests/shape_poly_test.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ def sampled_assertion(self,
128128
):
129129
"""Checks `assertion(e, fun(*operands))` symbolically and concretely.
130130
131-
For the concrete check, it will same the space of dimension variable
131+
For the concrete check, it will sample the space of dimension variable
132132
assignments for the dimension variables in `e`.
133133
134-
This is useful when `fun` can operate both with polynomials and with
135-
concrete values, and we want to double-check that the behavior is sound.
134+
This is useful when `fun` can operate both with symbolic and with
135+
concrete values, and we want to check that the behavior is sound.
136136
"""
137137
computed_sym = fun(*operands_sym)
138138
assertion_fun = {
@@ -1429,6 +1429,29 @@ def test_non_trivial_dim_expr(self, expr=lambda d: d % -2):
14291429
arg_descriptors=[RandArg((3,), np.int64)],
14301430
polymorphic_shapes=["b"])
14311431

1432+
@jtu.parameterized_filterable(
1433+
# The function `f` will be called with x: f32[b]
1434+
kwargs=[
1435+
dict(testcase_name="cube", f=lambda x: x.shape[0] ** 3),
1436+
dict(testcase_name="zero", f=lambda x: x.shape[0] ** 0),
1437+
dict(testcase_name="rpow", f=lambda x: 2 ** x.shape[0]),
1438+
dict(testcase_name="negative",
1439+
f=lambda x: x.shape[0] ** -2,
1440+
expect_error=(ValueError, "cannot be raised to negative powers")),
1441+
dict(testcase_name="non_integer",
1442+
f=lambda x: x.shape[0] ** 1.5,
1443+
expect_error=(ValueError, "cannot be raised to non-integer powers")),
1444+
dict(testcase_name="sym_pow",
1445+
f=lambda x: x.shape[0] ** x.shape[0]),
1446+
]
1447+
)
1448+
def test_pow(self, f, expect_error: tuple[Exception, str] | None = None):
1449+
check_shape_poly(self,
1450+
f,
1451+
arg_descriptors=[RandArg((3,), np.float32)],
1452+
polymorphic_shapes=["b"],
1453+
expect_error=expect_error)
1454+
14321455
def test_static_shape_result(self):
14331456
"""The result has static shape."""
14341457

0 commit comments

Comments
 (0)