
Commit dc11d40

ayaka14732 authored and Google-ML-Automation committed
[Pallas TPU] Better error message for lowering sp.broadcast_to_p
`sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this. For example, users currently hit this error when writing:

```
def kernel(x_ref, o_ref):
  m, n = 32, 8
  x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None],
                      jnp.arange(n, dtype=jnp.int32)[None]))
  o_ref[...] = x
```

PiperOrigin-RevId: 700290975
1 parent 231967f commit dc11d40

File tree

1 file changed: +12 −0 lines changed


jax/_src/pallas/mosaic/lowering.py

Lines changed: 12 additions & 0 deletions
@@ -1545,6 +1545,18 @@ def _proxy_reduce(arg, *, axes):
```diff
 lowering_rules[lax.reduce_or_p] = _reduce_or_lowering_rule


+def _broadcast_to_lowering_rule(
+    ctx: LoweringRuleContext, x, shape: Sequence[int]
+):
+  raise RuntimeError(
+      "`broadcast_to` is a Triton-specific primitive. Please consider using"
+      " `jnp.broadcast_to` instead."
+  )
+
+
+lowering_rules[state_primitives.broadcast_to_p] = _broadcast_to_lowering_rule
+
+
 def _broadcast_in_dim_lowering_rule(
     ctx: LoweringRuleContext, val, *, shape, broadcast_dimensions, sharding
 ):
```
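As the new error message suggests, kernels that previously hit `sp.broadcast_to_p` on TPU can use `jnp.broadcast_to` instead. Below is a minimal sketch (not part of the commit) using the shapes from the commit message's example; `interpret=True` runs the kernel in Pallas interpreter mode, so no TPU is required:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def kernel(x_ref, o_ref):
    # Broadcast the (8,) input row across 32 rows with jnp.broadcast_to,
    # avoiding the gather-style pl.load with explicit index grids.
    o_ref[...] = jnp.broadcast_to(x_ref[...][None, :], (32, 8))

x = jnp.arange(8, dtype=jnp.int32)
out = pl.pallas_call(
    kernel,
    out_shape=jax.ShapeDtypeStruct((32, 8), jnp.int32),
    interpret=True,  # interpreter mode: runs on CPU, no TPU required
)(x)
```

Every row of `out` is a copy of `x`, which is what the original `pl.load`-based kernel was trying to express.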

Comments (0)