Skip to content

Commit 8b89adc

Browse files
Plumb dot dimension numbers into TPU matmul op.
PiperOrigin-RevId: 694268559
1 parent eb2dd2a commit 8b89adc

File tree

1 file changed

+76
-17
lines changed

1 file changed

+76
-17
lines changed

jax/_src/pallas/mosaic/lowering.py

Lines changed: 76 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,6 +1564,71 @@ def _proxy_fun(val, *, shape, broadcast_dimensions):
15641564
lowering_rules[lax.broadcast_in_dim_p] = _broadcast_in_dim_lowering_rule
15651565

15661566

1567+
def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
  """Converts JAX dot dimension numbers to TPU dot dimension numbers.

  JAX dot dimension numbers are given as a tuple of tuples of sequences of ints
  of the form ((lhs_contracting_dims, rhs_contracting_dims), (lhs_batch_dims,
  rhs_batch_dims)).

  TPU dot dimension numbers are given as an MLIR attribute of the form
  #tpu.dot_dimension_numbers, which can be found in the TPU dialect definition
  file, tpu.td.

  Args:
    dimension_numbers: JAX dot dimension numbers, in the form described above.
    lhs_shape: Shape of the left-hand operand. Only its rank is used.
    rhs_shape: Shape of the right-hand operand. Only its rank is used.

  Returns:
    The attribute produced by `ir.Attribute.parse` from the textual
    `#tpu.dot_dimension_numbers<...>` form. Its fields are, in order: lhs
    contracting dims, rhs contracting dims, lhs non-contracting dims, rhs
    non-contracting dims, output dim order, lhs batch dims, rhs batch dims.
  """
  contracting_dims, batch_dims = dimension_numbers
  lhs_contracting_dims, rhs_contracting_dims = contracting_dims
  lhs_batch_dims, rhs_batch_dims = batch_dims

  # Non-contracting ("free") dims are whatever remains of each operand once its
  # contracting and batch dims are removed, in increasing order.
  lhs_non_contracting_dims = sorted(
      set(range(len(lhs_shape)))
      - set(lhs_contracting_dims)
      - set(lhs_batch_dims)
  )
  rhs_non_contracting_dims = sorted(
      set(range(len(rhs_shape)))
      - set(rhs_contracting_dims)
      - set(rhs_batch_dims)
  )

  # Create output_dim_order: a flat list of (operand, dim) pairs, where operand
  # 0 is the lhs and operand 1 is the rhs.
  # Note: we assume that the output dimensions are ordered as batch dims,
  # lhs_non_contracting_dims, rhs_non_contracting_dims - this assumption is
  # safe to make, as it is the same one made in jax's dot_general.
  output_dim_pairs = (
      [(0, dim) for dim in lhs_batch_dims]
      + [(0, dim) for dim in lhs_non_contracting_dims]
      + [(1, dim) for dim in rhs_non_contracting_dims]
  )
  output_dim_order = [entry for pair in output_dim_pairs for entry in pair]

  def format_dims(dims):
    return "[" + ", ".join(str(d) for d in dims) + "]"

  # The order of these fields is positional in the attribute's textual form.
  all_dims = (
      lhs_contracting_dims,
      rhs_contracting_dims,
      lhs_non_contracting_dims,
      rhs_non_contracting_dims,
      output_dim_order,
      lhs_batch_dims,
      rhs_batch_dims,
  )
  tpu_dim_numbers_str = (
      f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>"
  )

  return ir.Attribute.parse(tpu_dim_numbers_str)
1630+
1631+
15671632
def _dot_general_lowering_rule(
15681633
ctx: LoweringRuleContext, x, y, dimension_numbers, precision, **_
15691634
):
@@ -1589,7 +1654,7 @@ def _dot_general_lowering_rule(
15891654
raise NotImplementedError(
15901655
f"Only 2D tensors supported in dot; received: {ctx.avals_in}"
15911656
)
1592-
lhs_aval, _ = ctx.avals_in
1657+
lhs_aval, rhs_aval = ctx.avals_in
15931658
# This is really a matrix-vector product. It only looks like matrix-matrix.
15941659
if lhs_dims == (1,) and rhs_dims == (1,) and ctx.avals_in[1].shape[0] == 1:
15951660
if ctx.avals_in[0].shape != ctx.avals_in[1].shape:
@@ -1615,19 +1680,10 @@ def _dot_general_lowering_rule(
16151680
)
16161681
return vector.shape_cast(out_type, red)
16171682

1618-
# TODO(mvoz): Plumb these into dot dimension numbers on the matmul op!
1619-
if lhs_dims == (1,):
1620-
transpose_lhs = False
1621-
elif lhs_dims == (0,):
1622-
transpose_lhs = True
1623-
else:
1624-
raise NotImplementedError
1625-
if rhs_dims == (0,):
1626-
transpose_rhs = False
1627-
elif rhs_dims == (1,):
1628-
transpose_rhs = True
1629-
else:
1630-
raise NotImplementedError
1683+
tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims(
1684+
dimension_numbers, lhs_aval.shape, rhs_aval.shape
1685+
)
1686+
16311687
if precision is not None:
16321688
if precision[0] != precision[1]:
16331689
raise NotImplementedError("Per-operand dot precision unsupported")
@@ -1644,9 +1700,12 @@ def _dot_general_lowering_rule(
16441700
out_type, ir.DenseElementsAttr.get_splat(out_type, val)
16451701
)
16461702
return tpu.matmul(
1647-
out_type, x, y, out_tile,
1648-
transpose_lhs=transpose_lhs, transpose_rhs=transpose_rhs,
1649-
precision=precision_attr
1703+
out_type,
1704+
x,
1705+
y,
1706+
out_tile,
1707+
dimension_numbers=tpu_dot_dims,
1708+
precision=precision_attr,
16501709
)
16511710

16521711

0 commit comments

Comments
 (0)