
Commit 94b09dc

yueshengys authored and Google-ML-Automation committed
[Pallas/Mosaic TPU] Allow non-leading and non-matching batch dimensions in dot_general.
The constraints on `lhs_batch_dims` and `rhs_batch_dims` for `dot_general` in Pallas/Mosaic on TPU are now relaxed. Batch dimensions no longer have to be at the front of the shape, and the dimension indices used for batching on the LHS and RHS can differ. The remaining gap compared to JAX is the lack of support for multiple batch dimensions.

PiperOrigin-RevId: 833497703
1 parent ce76d05 commit 94b09dc
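
As a quick illustration (not part of the commit; it mirrors one of the new test cases below, and the kernel and variable names are made up), a Pallas kernel can now use dot_general with a batch dimension that is neither leading nor at the same index on both operands:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

# Contract on dim 2 of both operands; batch on dim 1 of the LHS and dim 0 of the RHS.
dims = (((2,), (2,)), ((1,), (0,)))

def kernel(x_ref, y_ref, out_ref):
  out_ref[...] = jax.lax.dot_general(x_ref[...], y_ref[...], dimension_numbers=dims)

x = jnp.ones((3, 4, 128), jnp.float32)
y = jnp.ones((4, 2, 128), jnp.float32)
# Let JAX infer the output shape: (4, 3, 2) = (batch, lhs non-contracting, rhs non-contracting).
out_shape = jax.lax.dot_general(x, y, dimension_numbers=dims).shape
out = pl.pallas_call(
    kernel, out_shape=jax.ShapeDtypeStruct(out_shape, jnp.float32)
)(x, y)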

File tree

3 files changed (+67, −40 lines)


jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 8 additions & 24 deletions
@@ -1124,15 +1124,8 @@ LogicalResult MatmulOp::verify() {
   const std::optional<int64_t> batch_dim_rhs =
       rhs_batch_dims.empty() ? std::nullopt
                              : std::optional<int64_t>(rhs_batch_dims[0]);
-  if (batch_dim_lhs != batch_dim_rhs) {
-    emitOpError("Not Implemented: batch dims must be equal");
-    return failure();
-  }
-  if (batch_dim_lhs.has_value() && (batch_dim_lhs.value() != 0)) {
-    emitOpError("Not Implemented: batch dims pos must be 0");
-    return failure();
-  }
-  // Invariant above enforces only 1 batch dim atm, and that both are eq
+
+  // Invariant above enforces only 1 batch dim atm.
   std::optional<int64_t> batch_size = std::nullopt;
   if (batch_dim_lhs.has_value()) {
     batch_size = lhs_ty.getShape()[batch_dim_lhs.value()];
@@ -1152,22 +1145,13 @@ LogicalResult MatmulOp::verify() {
         "Illegal: output dim order must have an even number of elements.");
     return failure();
   }
-  if (batch_size.has_value()) {
-    if (output_dim_order[0] != 0 || output_dim_order[1] != 0) {
-      emitOpError(
-          "Not implemented: Output with batch size must be the lhs 0 idx for "
-          "now.");
-      return failure();
-    }
-  }
 
-  // Invariants above enforce a single batch idx for now, and that it is in
-  // position 0. Future extensions to this will be to:
-  // 1. Support multiple batch dims
-  // 2. Support batch dims in any position in the output dim order
+  // Invariants above enforce a single batch idx for now. Future extension to
+  // this will be to support multiple batch dims.
 
-  // Verify that the output dim order is always in the form of [0, batch_dims,
-  // 0, lhs_non_contracting_dims, 1, rhs_non_contracting_dims].
+  // Verify that the output dim order is always in the form of [0,
+  // lhs_batch_dims, 0, lhs_non_contracting_dims, 1,
+  // rhs_non_contracting_dims].
   llvm::SmallVector<int64_t> expected_output_dim_order;
   expected_output_dim_order.reserve(2 * (lhs_batch_dims.size() +
                                          lhs_non_contracting_dims.size() +
@@ -1187,7 +1171,7 @@ LogicalResult MatmulOp::verify() {
   if (!absl::c_equal(output_dim_order, expected_output_dim_order)) {
     emitOpError(
         "Illegal: output dim order must be in the form of [0, "
-        "batch_dims, 0, lhs_non_contracting_dims, 1, "
+        "lhs_batch_dims, 0, lhs_non_contracting_dims, 1, "
         "rhs_non_contracting_dims]");
     return failure();
   }
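
To make the verified layout concrete, here is a small Python sketch (not from the commit; the function name is made up) of how the expected output_dim_order is built as (operand, dim) pairs, where 0 refers to the LHS and 1 to the RHS:

# Illustrative only: mirrors the verifier's expected_output_dim_order construction.
def expected_output_dim_order(lhs_batch_dims, lhs_non_contracting_dims,
                              rhs_non_contracting_dims):
  order = []
  for d in lhs_batch_dims:            # batch dims first, indexed on the LHS
    order += [0, d]
  for d in lhs_non_contracting_dims:  # then LHS non-contracting dims
    order += [0, d]
  for d in rhs_non_contracting_dims:  # then RHS non-contracting dims
    order += [1, d]
  return order

# e.g. LHS batch dim 1, LHS non-contracting dim 0, RHS non-contracting dim 1:
# expected_output_dim_order([1], [0], [1]) == [0, 1, 0, 0, 1, 1]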

jaxlib/mosaic/dialect/tpu/transforms/canonicalize_mosaic.cc

Lines changed: 20 additions & 16 deletions
@@ -132,24 +132,24 @@ class CanonicalBuilder : public ImplicitLocOpBuilder {
   Operation *op_;
 };
 
-// Ensures both lhs and rhs have contiguous non-contracting and contracting
-// dimensions by inserting transposes if needed. Returns lhs, rhs, and new
-// dimension numbers if a transpose was inserted, otherwise returns
-// std::nullopt.
+// Ensures both lhs and rhs are in form of [batch_dims, non_contracting_dims,
+// contracting_dims] or [batch_dims, contracting_dims, non_contracting_dims] by
+// inserting transposes if needed. Returns lhs, rhs, and new dimension numbers
+// if a transpose was inserted, otherwise returns std::nullopt.
 std::optional<std::tuple<TypedValue<VectorType>, TypedValue<VectorType>,
                          DotDimensionNumbersAttr>>
 ensure_matmul_contiguous_dims(
     CanonicalBuilder& builder, TypedValue<VectorType> lhs,
     TypedValue<VectorType> rhs,
     const DotDimensionNumbersAttr& dimension_numbers) {
-  // Returns a tuple of [new_operand, new_non_contracting_dims,
+  // Returns a tuple of [new_operand, new_batch_dims, new_non_contracting_dims,
   // new_contracting_dims]. new_operand is nullptr if no transpose is inserted.
   auto maybe_insert_transpose =
       [&](TypedValue<VectorType> operand, ArrayRef<int64_t> batch_dims,
           ArrayRef<int64_t> non_contracting_dims,
           ArrayRef<int64_t> contracting_dims, bool is_lhs)
       -> std::tuple<TypedValue<VectorType>, SmallVector<int64_t>,
-                    SmallVector<int64_t>> {
+                    SmallVector<int64_t>, SmallVector<int64_t>> {
     VectorType vty = operand.getType();
     auto shape = vty.getShape();
     auto rank = shape.size();
@@ -170,7 +170,8 @@ ensure_matmul_contiguous_dims(
                     contracting_dims.end());
     // Already in [B..., NC..., C...].
     if (is_identity(perm_BNC)) {
-      return {nullptr, llvm::to_vector(non_contracting_dims),
+      return {nullptr, llvm::to_vector(batch_dims),
+              llvm::to_vector(non_contracting_dims),
              llvm::to_vector(contracting_dims)};
     }
 
@@ -183,7 +184,8 @@ ensure_matmul_contiguous_dims(
                     non_contracting_dims.end());
     // Already in [B..., C..., NC...].
     if (is_identity(perm_BCN)) {
-      return {nullptr, llvm::to_vector(non_contracting_dims),
+      return {nullptr, llvm::to_vector(batch_dims),
+              llvm::to_vector(non_contracting_dims),
              llvm::to_vector(contracting_dims)};
     }
 
@@ -246,18 +248,21 @@ ensure_matmul_contiguous_dims(
     };
 
     // Map the dimension indices to the new dimension order.
+    SmallVector<int64_t> new_b = map_dims(batch_dims);
    SmallVector<int64_t> new_c = map_dims(contracting_dims);
    SmallVector<int64_t> new_nc = map_dims(non_contracting_dims);
 
-    return {new_operand, new_nc, new_c};
+    return {new_operand, new_b, new_nc, new_c};
   };
 
-  auto [new_lhs, new_lhs_non_contracting_dims, new_lhs_contracting_dims] =
+  auto [new_lhs, new_lhs_batch_dims, new_lhs_non_contracting_dims,
+        new_lhs_contracting_dims] =
       maybe_insert_transpose(lhs, dimension_numbers.getLhsBatchDims(),
                              dimension_numbers.getLhsNonContractingDims(),
                              dimension_numbers.getLhsContractingDims(),
                              /*is_lhs=*/true);
-  auto [new_rhs, new_rhs_non_contracting_dims, new_rhs_contracting_dims] =
+  auto [new_rhs, new_rhs_batch_dims, new_rhs_non_contracting_dims,
+        new_rhs_contracting_dims] =
       maybe_insert_transpose(rhs, dimension_numbers.getRhsBatchDims(),
                              dimension_numbers.getRhsNonContractingDims(),
                              dimension_numbers.getRhsContractingDims(),
@@ -267,10 +272,10 @@ ensure_matmul_contiguous_dims(
   }
 
   SmallVector<int64_t> new_output_dim_order;
-  new_output_dim_order.reserve(2 * (dimension_numbers.getLhsBatchDims().size() +
+  new_output_dim_order.reserve(2 * (new_lhs_batch_dims.size() +
                                     new_lhs_non_contracting_dims.size() +
                                     new_rhs_non_contracting_dims.size()));
-  for (int64_t batch_dim : dimension_numbers.getLhsBatchDims()) {
+  for (int64_t batch_dim : new_lhs_batch_dims) {
     new_output_dim_order.push_back(0);
     new_output_dim_order.push_back(batch_dim);
   }
@@ -286,8 +291,7 @@ ensure_matmul_contiguous_dims(
   DotDimensionNumbersAttr new_dimension_numbers = DotDimensionNumbersAttr::get(
       builder.getContext(), new_lhs_contracting_dims, new_rhs_contracting_dims,
       new_lhs_non_contracting_dims, new_rhs_non_contracting_dims,
-      new_output_dim_order, dimension_numbers.getLhsBatchDims(),
-      dimension_numbers.getRhsBatchDims());
+      new_output_dim_order, new_lhs_batch_dims, new_rhs_batch_dims);
 
   return std::make_tuple(new_lhs ? new_lhs : lhs, new_rhs ? new_rhs : rhs,
                          new_dimension_numbers);
@@ -562,7 +566,7 @@ FailureOr<Value> canonicalize_matmul(const CanonicalizeContext &ctx,
           "dim and rhs must be vector-like [B, K] or [B, 1, K].");
     }
 
-    auto extsi_sitofp = [&builder, &op](
+    auto extsi_sitofp = [&builder](
                             TypedValue<VectorType> element,
                             std::optional<FloatType> maybe_dest = std::nullopt) {
       FloatType dest = maybe_dest.value_or(builder.getF32Type());
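
For intuition, the dimension-index bookkeeping that maybe_insert_transpose now also returns for batch dims can be sketched in Python (illustrative only; the real helper builds an MLIR transpose and works on SmallVectors):

# After transposing an operand into [batch..., non_contracting..., contracting...],
# each original dim index is replaced by its position in that new order.
def remap_dims_after_transpose(batch_dims, non_contracting_dims, contracting_dims):
  perm = list(batch_dims) + list(non_contracting_dims) + list(contracting_dims)
  new_pos = {old: new for new, old in enumerate(perm)}
  new_b = [new_pos[d] for d in batch_dims]
  new_nc = [new_pos[d] for d in non_contracting_dims]
  new_c = [new_pos[d] for d in contracting_dims]
  return new_b, new_nc, new_c

# e.g. an LHS of rank 3 with batch dim 1, non-contracting dim 0, contracting dim 2:
# remap_dims_after_transpose([1], [0], [2]) == ([0], [1], [2])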

tests/pallas/ops_test.py

Lines changed: 39 additions & 0 deletions
@@ -1498,6 +1498,45 @@ def kernel(x_ref, y_ref, out_ref):
         expected,
     )
 
+  @parameterized.product(
+      shapes_and_dims_numbers=(
+          ((3, 4, 128), (4, 2, 128), (((2,), (2,)), ((1,), (0,)))),
+          ((3, 4, 128), (2, 4, 128), (((2,), (2,)), ((1,), (1,)))),
+          ((3, 4, 256), (2, 3, 256), (((2,), (2,)), ((0,), (1,)))),
+          ((4, 3, 2, 32), (2, 128, 32, 2), (((3,), (2,)), ((2,), (3,)))),
+      ),
+  )
+  def test_dot_general_non_front_batch_dims(self, shapes_and_dims_numbers):
+    if jtu.test_device_matches(["gpu"]):
+      self.skipTest("TPU only test")
+
+    if jtu.test_device_matches(["tpu"]) and not jtu.if_cloud_tpu_at_least(
+        2025, 11, 21
+    ):
+      self.skipTest("Requires libtpu built after 2025-11-21")
+
+    x_shape, y_shape, dims_numbers = shapes_and_dims_numbers
+
+    k1, k2 = random.split(jax.random.key(0))
+    x = jax.random.normal(k1, x_shape, dtype=jnp.float32)
+    y = jax.random.normal(k2, y_shape, dtype=jnp.float32)
+
+    # Just infer shape from jax.
+    expected = jax.lax.dot_general(x, y, dimension_numbers=dims_numbers)
+
+    @functools.partial(
+        self.pallas_call,
+        out_shape=jax.ShapeDtypeStruct(expected.shape, jnp.float32),
+    )
+    def kernel(x_ref, y_ref, out_ref):
+      out_ref[...] = jax.lax.dot_general(
+          x_ref[...],
+          y_ref[...],
+          dimension_numbers=dims_numbers,
+      )
+
+    np.testing.assert_allclose(kernel(x, y), expected, atol=1e-5, rtol=1e-5)
+
   @parameterized.product(
       batch_size=(None, 1, 2),
       # dims_numbers is without batch dims
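
Each shapes_and_dims_numbers entry above uses the standard jax.lax.dot_general convention ((lhs_contracting, rhs_contracting), (lhs_batch, rhs_batch)), and the output layout is (batch dims, LHS non-contracting dims, RHS non-contracting dims). As a quick sanity check (not part of the commit), the last case gives:

import jax
import jax.numpy as jnp

x = jnp.zeros((4, 3, 2, 32))
y = jnp.zeros((2, 128, 32, 2))
# Contract dim 3 of x with dim 2 of y; batch dim 2 of x with dim 3 of y.
out = jax.lax.dot_general(x, y, dimension_numbers=(((3,), (2,)), ((2,), (3,))))
print(out.shape)  # (2, 4, 3, 2, 128): batch 2, then x's 4 and 3, then y's 2 and 128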
