
Commit 29bfd00

bythew3i authored and Google-ML-Automation committed
[Pallas TPU] Fix preferred_element_type propagation in dot_general with const
PiperOrigin-RevId: 735903687
1 parent 13eb8d3 commit 29bfd00
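
For context, preferred_element_type asks dot_general to produce (and accumulate in) a wider dtype than its operands, e.g. bf16 inputs with an f32 result. The sketch below shows the kind of Pallas TPU kernel this change is about, modeled on the new test further down; the shapes and the all-ones constant are illustrative, not part of the commit itself.

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

d = 1024  # illustrative contraction size

def kernel(x_ref, out_ref):
  # A bf16 row contracted against a bf16 constant; the f32 result dtype is
  # requested via preferred_element_type and must survive the Mosaic lowering.
  out_ref[:] = jax.lax.dot_general(
      x_ref[:],
      jnp.ones((1, d), jnp.bfloat16),
      (((1,), (1,)), ((), ())),
      preferred_element_type=jnp.float32,
  )

x = jnp.ones((1, d), jnp.bfloat16)
out = pl.pallas_call(kernel, out_shape=jax.ShapeDtypeStruct((1, 1), jnp.float32))(x)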

File tree

2 files changed, +53 -2 lines changed


jax/_src/pallas/mosaic/lowering.py

Lines changed: 32 additions & 2 deletions
@@ -1853,7 +1853,13 @@ def format_dims(dims):
 
 
 def _dot_general_lowering_rule(
-    ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
+    ctx: LoweringRuleContext,
+    x,
+    y,
+    dimension_numbers,
+    precision,
+    preferred_element_type,
+    **_,
 ):
   (lhs_dims, rhs_dims), _ = dimension_numbers
   (aval_out,) = ctx.avals_out
@@ -1894,10 +1900,34 @@ def _dot_general_lowering_rule(
       x = vector.broadcast(bcast_shape, x)
     if ctx.avals_in[1].shape != bcast_shape:
       y = vector.broadcast(bcast_shape, y)
+    red_dtype = (
+        preferred_element_type if preferred_element_type else lhs_aval.dtype
+    )
     red_type = aval_to_ir_type(
         ctx.lowering_context.dynamic_shape_replacement_fn,
-        lhs_aval.update(shape=(lhs_aval.shape[0],)),
+        lhs_aval.update(shape=(lhs_aval.shape[0],), dtype=red_dtype),
     )
+
+    if lhs_aval.dtype != red_dtype:
+      lhs_type = aval_to_ir_type(
+          ctx.lowering_context.dynamic_shape_replacement_fn,
+          lhs_aval.update(shape=lhs_aval.shape, dtype=red_dtype),
+      )
+      if red_dtype == jnp.float32:
+        x = arith.extf(lhs_type, x)
+      else:
+        raise NotImplementedError(f"Unsupported {preferred_element_type=}")
+
+    if rhs_aval.dtype != red_dtype:
+      rhs_type = aval_to_ir_type(
+          ctx.lowering_context.dynamic_shape_replacement_fn,
+          rhs_aval.update(shape=rhs_aval.shape, dtype=red_dtype),
+      )
+      if red_dtype == jnp.float32:
+        y = arith.extf(rhs_type, y)
+      else:
+        raise NotImplementedError(f"Unsupported {preferred_element_type=}")
+
     acc = arith.ConstantOp(
         red_type, ir.DenseElementsAttr.get_splat(red_type, val)
     )
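
The new reduction-path logic above reads as follows: pick the reduction dtype from preferred_element_type (falling back to the operand dtype), and if an operand is narrower, widen it with arith.extf before accumulating; only widening to f32 is handled, anything else raises NotImplementedError. Below is a rough Python-level sketch of that behavior, for intuition only; the real rule emits Mosaic/MLIR ops, and reference_dot is a made-up name, not part of the change.

import jax.numpy as jnp

def reference_dot(x, y, preferred_element_type=None):
  # Reduction dtype: honor preferred_element_type, otherwise keep the lhs dtype.
  red_dtype = preferred_element_type if preferred_element_type else x.dtype
  if x.dtype != red_dtype:
    if red_dtype != jnp.float32:
      raise NotImplementedError(f"Unsupported {preferred_element_type=}")
    x = x.astype(red_dtype)  # plays the role of arith.extf on the lhs
  if y.dtype != red_dtype:
    if red_dtype != jnp.float32:
      raise NotImplementedError(f"Unsupported {preferred_element_type=}")
    y = y.astype(red_dtype)  # plays the role of arith.extf on the rhs
  # Multiply and reduce over the contracting dimension, accumulating in red_dtype.
  return jnp.sum(x * y, axis=-1, keepdims=True)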

tests/pallas/tpu_ops_test.py

Lines changed: 21 additions & 0 deletions
@@ -470,6 +470,27 @@ def kernel(x, out):
     expected = lax.select(concated_mask, concated_x, jnp.zeros_like(concated_x))
     np.testing.assert_array_equal(out, expected)
 
+  def test_reduce_with_const(self):
+    m = 1
+    d = 1024
+    x = jnp.ones((m, d), jnp.bfloat16)
+
+    def dot(x, y):
+      return jax.lax.dot_general(
+          x,
+          y,
+          (((1,), (1,)), ((), ())),
+          preferred_element_type=jnp.float32,
+      )
+
+    def kernel(x, out):
+      out[:] = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
+
+    run = pl.pallas_call(kernel, jax.ShapeDtypeStruct((m, 1), jnp.float32))
+    output = run(x)
+    expected = dot(x[:], jnp.ones((1, d), jnp.bfloat16))
+    np.testing.assert_array_equal(output, expected)
+
 
 class OpsInterpretTest(OpsTest):
   INTERPRET = True
