Commit dc11d40
[Pallas TPU] Better error message for lowering
`sp.broadcast_to_p` is a GPU-specific primitive, but it mistakenly appears in TPU lowerings. This PR improves the error message to reflect this.
As an example, currently, users will hit this error when doing:
```
def kernel(x_ref, o_ref):
m, n = 32, 8
x = pl.load(x_ref, (jnp.arange(m, dtype=jnp.int32)[:, None], jnp.arange(n, dtype=jnp.int32)[None]))
o_ref[...] = x
```
PiperOrigin-RevId: 700290975sp.broadcast_to_p
1 parent 231967f commit dc11d40
1 file changed
+12
-0
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
1545 | 1545 | | |
1546 | 1546 | | |
1547 | 1547 | | |
| 1548 | + | |
| 1549 | + | |
| 1550 | + | |
| 1551 | + | |
| 1552 | + | |
| 1553 | + | |
| 1554 | + | |
| 1555 | + | |
| 1556 | + | |
| 1557 | + | |
| 1558 | + | |
| 1559 | + | |
1548 | 1560 | | |
1549 | 1561 | | |
1550 | 1562 | | |
| |||
0 commit comments