Skip to content

Commit 081eaea

Browse files
committed
Don't use an out-of-line lowering for integer_pow for small powers.
This yields a smaller stablehlo output. Add a fast path for y == 1 and y == -1, which turn out to be reasonably common.
1 parent aefe621 commit 081eaea

File tree

1 file changed, with 9 additions and 9 deletions.

jax/_src/lax/lax.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -2633,24 +2633,24 @@ def _integer_pow(x, *, y):
26332633
def _integer_pow_lowering(ctx, x, *, y):
26342634
# These cases are subsumed by the general case, but it's faster to emit these
26352635
# common cases directly.
2636-
if y == 2:
2636+
if y == 1:
2637+
out = x
2638+
elif y == 2:
26372639
out = hlo.multiply(x, x)
26382640
elif y == 3:
26392641
out = hlo.multiply(hlo.multiply(x, x), x)
2642+
elif y == -1:
2643+
out = hlo.divide(mlir.full_like_aval(ctx, 1, ctx.avals_in[0]), x)
26402644
else:
26412645
lowering = mlir.lower_fun(_integer_pow, multiple_results=False)
2642-
# TODO(b/217551391): emitting an out-of-line call leads to a large
2643-
# expansion when the MLIR is lowered to HLO, because the HLO lowering
2644-
# clones the callee. Consider unconditionally caching when the MLIR->HLO
2645-
# lowering doesn't expand the program.
2646-
lowering = mlir.cache_lowering(lowering)
2647-
out = lowering(ctx, x, y=y)
2646+
if builtins.abs(y) >= 3:
2647+
lowering = mlir.cache_lowering(lowering)
2648+
out, = lowering(ctx, x, y=y)
26482649
if config.sharding_in_types.value:
26492650
aval_out, = ctx.avals_out
26502651
proto = aval_out.sharding._to_xla_hlo_sharding(aval_out.ndim).to_proto()
2651-
out = out[0] if isinstance(out, list) else out
26522652
return [mlir.wrap_with_sharding_op(ctx, out, aval_out, proto)]
2653-
return out if isinstance(out, list) else [out]
2653+
return [out]
26542654

26552655
mlir.register_lowering(integer_pow_p, _integer_pow_lowering)
26562656

0 commit comments

Comments
 (0)