@@ -1564,6 +1564,71 @@ def _proxy_fun(val, *, shape, broadcast_dimensions):
15641564lowering_rules [lax .broadcast_in_dim_p ] = _broadcast_in_dim_lowering_rule
15651565
15661566
def jax_dot_dims_to_tpu_dot_dot_dims(dimension_numbers, lhs_shape, rhs_shape):
  """Converts jax dot dimension numbers to tpu dot dimension numbers.

  Jax dot dimension numbers are given as a tuple of tuples of sequences of
  ints of the form ((lhs_contracting_dims, rhs_contracting_dims),
  (lhs_batch_dims, rhs_batch_dims)).

  TPU dot dimension numbers are given as an MLIR definition of the form
  #tpu.dot_dimension_numbers - which can be found in the tpu dialect
  definition file, tpu.td.

  Args:
    dimension_numbers: The jax dot_general dimension numbers, as described
      above.
    lhs_shape: Shape (sequence of ints) of the left-hand operand.
    rhs_shape: Shape (sequence of ints) of the right-hand operand.

  Returns:
    An `ir.Attribute` parsed from the `#tpu.dot_dimension_numbers<...>`
    textual form.
  """
  (contracting_dims, batch_dims) = dimension_numbers
  lhs_contracting_dims, rhs_contracting_dims = contracting_dims
  lhs_batch_dims, rhs_batch_dims = batch_dims

  lhs_total_dims = set(range(len(lhs_shape)))
  rhs_total_dims = set(range(len(rhs_shape)))

  # Non-contracting dims are whatever is left over once contracting and
  # batch dims are removed, in ascending order.
  lhs_non_contracting_dims = sorted(
      lhs_total_dims - set(lhs_contracting_dims) - set(lhs_batch_dims)
  )
  rhs_non_contracting_dims = sorted(
      rhs_total_dims - set(rhs_contracting_dims) - set(rhs_batch_dims)
  )

  # Create output_dim_order as a flat list of (operand_index, dim) pairs,
  # where operand_index is 0 for lhs and 1 for rhs.
  # Note: we assume that the output dimensions are ordered as batch dims,
  # lhs_non_contracting_dims, rhs_non_contracting_dims - this assumption is
  # safe to make, as it is the same one made in jax's dot_general.
  output_dim_order = []

  for dim in lhs_batch_dims:
    output_dim_order.append(0)
    output_dim_order.append(dim)

  for dim in lhs_non_contracting_dims:
    output_dim_order.append(0)
    output_dim_order.append(dim)

  for dim in rhs_non_contracting_dims:
    output_dim_order.append(1)
    output_dim_order.append(dim)

  def format_dims(dims):
    # Render a dim sequence in the MLIR list syntax, e.g. "[0, 1]".
    return "[" + ", ".join(str(d) for d in dims) + "]"

  all_dims = (
      lhs_contracting_dims,
      rhs_contracting_dims,
      lhs_non_contracting_dims,
      rhs_non_contracting_dims,
      output_dim_order,
      lhs_batch_dims,
      rhs_batch_dims,
  )
  tpu_dim_numbers_str = (
      f"#tpu.dot_dimension_numbers<{','.join(map(format_dims, all_dims))}>"
  )

  return ir.Attribute.parse(tpu_dim_numbers_str)
1630+
1631+
15671632def _dot_general_lowering_rule (
15681633 ctx : LoweringRuleContext , x , y , dimension_numbers , precision , ** _
15691634):
@@ -1589,7 +1654,7 @@ def _dot_general_lowering_rule(
15891654 raise NotImplementedError (
15901655 f"Only 2D tensors supported in dot; received: { ctx .avals_in } "
15911656 )
1592- lhs_aval , _ = ctx .avals_in
1657+ lhs_aval , rhs_aval = ctx .avals_in
15931658 # This is really a matrix-vector product. It only looks like matrix-matrix.
15941659 if lhs_dims == (1 ,) and rhs_dims == (1 ,) and ctx .avals_in [1 ].shape [0 ] == 1 :
15951660 if ctx .avals_in [0 ].shape != ctx .avals_in [1 ].shape :
@@ -1615,19 +1680,10 @@ def _dot_general_lowering_rule(
16151680 )
16161681 return vector .shape_cast (out_type , red )
16171682
1618- # TODO(mvoz): Plumb these into dot dimension numbers on the matmul op!
1619- if lhs_dims == (1 ,):
1620- transpose_lhs = False
1621- elif lhs_dims == (0 ,):
1622- transpose_lhs = True
1623- else :
1624- raise NotImplementedError
1625- if rhs_dims == (0 ,):
1626- transpose_rhs = False
1627- elif rhs_dims == (1 ,):
1628- transpose_rhs = True
1629- else :
1630- raise NotImplementedError
1683+ tpu_dot_dims = jax_dot_dims_to_tpu_dot_dot_dims (
1684+ dimension_numbers , lhs_aval .shape , rhs_aval .shape
1685+ )
1686+
16311687 if precision is not None :
16321688 if precision [0 ] != precision [1 ]:
16331689 raise NotImplementedError ("Per-operand dot precision unsupported" )
@@ -1644,9 +1700,12 @@ def _dot_general_lowering_rule(
16441700 out_type , ir .DenseElementsAttr .get_splat (out_type , val )
16451701 )
16461702 return tpu .matmul (
1647- out_type , x , y , out_tile ,
1648- transpose_lhs = transpose_lhs , transpose_rhs = transpose_rhs ,
1649- precision = precision_attr
1703+ out_type ,
1704+ x ,
1705+ y ,
1706+ out_tile ,
1707+ dimension_numbers = tpu_dot_dims ,
1708+ precision = precision_attr ,
16501709 )
16511710
16521711
0 commit comments