Skip to content

Commit 19a51de

Browse files
Merge pull request jax-ml#24897 from hawkinsp:ipow
PiperOrigin-RevId: 696581990
2 parents fcde8aa + 081eaea commit 19a51de

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

jax/_src/lax/lax.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2631,24 +2631,24 @@ def _integer_pow(x, *, y):
26312631
def _integer_pow_lowering(ctx, x, *, y):
26322632
# These cases are subsumed by the general case, but it's faster to emit these
26332633
# common cases directly.
2634-
if y == 2:
2634+
if y == 1:
2635+
out = x
2636+
elif y == 2:
26352637
out = hlo.multiply(x, x)
26362638
elif y == 3:
26372639
out = hlo.multiply(hlo.multiply(x, x), x)
2640+
elif y == -1:
2641+
out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x)
26382642
else:
26392643
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
2640-
# TODO(b/217551391): emitting an out-of-line call leads to a large
2641-
# expansion when the MLIR is lowered to HLO, because the HLO lowering
2642-
# clones the callee. Consider unconditionally caching when the MLIR->HLO
2643-
# lowering doesn't expand the program.
2644-
lowering = mlir.cache_lowering(lowering)
2645-
out = lowering(ctx, x, y=y)
2644+
if builtins.abs(y) >= 3:
2645+
lowering = mlir.cache_lowering(lowering)
2646+
out, = lowering(ctx, x, y=y)
26462647
if config.sharding_in_types.value:
26472648
aval_out, = ctx.avals_out
26482649
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
2649-
out = out[0] if isinstance(out, list) else out
26502650
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
2651-
return out if isinstance(out, list) else [out]
2651+
return [out]
26522652

26532653
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
26542654

0 commit comments

Comments
 (0)