Skip to content

Commit cb82609

Browse files
chr1sj0nesGoogle-ML-Automation
authored andcommitted
[pallas:triton] Fix reshape lowering with scalar output shape.
PiperOrigin-RevId: 695678909
1 parent 5ec0876 commit cb82609

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

jax/_src/pallas/triton/lowering.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1624,6 +1624,11 @@ def _reshape_lowering_rule(
16241624
return _splat(a, out_aval.shape)
16251625

16261626
ty = ir.RankedTensorType(a.type)
1627+
1628+
# Triton Reshape doesn't support scalar result types (only 0d tensors).
1629+
if not out_aval.shape:
1630+
return _reduce_lowering(jnp.add, ctx, a, axes=tuple(range(ty.rank)))
1631+
16271632
return tt_dialect.reshape(
16281633
ir.RankedTensorType.get([*out_aval.shape], ty.element_type, ty.encoding),
16291634
a,

0 commit comments

Comments
 (0)