
Commit 38d062d

bythew3i authored and Google-ML-Automation committed
[Mosaic TPU] Support dynamic DMA and ref slice on the 2nd minor when memref is untiled
* Generalize any untiled memref to have tiling (packing, 128).
* Support dynamic index on 2nd minor.
* Support dynamic shape on 2nd minor.

PiperOrigin-RevId: 695516124
1 parent 6892e62 commit 38d062d
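For context on the tiling named in the commit message: "packing" is the number of narrow elements that share one 32-bit word, computed in the diffs below as 32 / element bitwidth. A minimal illustrative Python sketch (the helper name is hypothetical, not part of this change):

# Hypothetical helper showing the (packing, 128) tiling this change assigns
# to untiled memrefs: narrower element types get taller sublane tiles.
def untiled_memref_tiling(element_bitwidth: int, lane_count: int = 128):
    packing = 32 // element_bitwidth
    return (packing, lane_count)

assert untiled_memref_tiling(32) == (1, 128)  # f32/i32
assert untiled_memref_tiling(16) == (2, 128)  # bf16
assert untiled_memref_tiling(8) == (4, 128)   # i8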

7 files changed: +165 −25 lines


jaxlib/mosaic/dialect/tpu/transforms/apply_vector_layout.cc

Lines changed: 2 additions & 2 deletions
@@ -2996,7 +2996,7 @@ LogicalResult vector_load_rule(RewriteContext &ctx, Operation &op,
       // TODO(b/295393167): need to support strided load for bitwidth < 32.
     } else if (layout_out.bitwidth() == 32 &&
                canReinterpretToUntiledMemref(
-                   memref_ty, ctx.target_shape,
+                   load_op.getBase(), ctx.target_shape,
                    /*allow_minormost_padding=*/true)) {
       // In this case, if the memref can be reinterpreted to untiled, it is
       // valid to use any tiling for output. But using native tiling can save us
@@ -4204,7 +4204,7 @@ LogicalResult vector_store_rule(RewriteContext &ctx, Operation &op,
           // We accept padding in the minormost dim, because
           // apply_vector_layout will properly mask stores.
           canReinterpretToUntiledMemref(
-              memref_ty, ctx.target_shape,
+              store_op.getBase(), ctx.target_shape,
               /*allow_minormost_padding=*/true)) {
       // In this case, if the memref can be reinterpreted to untiled, it is
       // valid to use any tiling for to_store. But using native tiling can save
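Both rules now pass the predicate the load/store base value rather than its MemRefType; as the util.cc change below shows, canReinterpretToUntiledMemref needs the SSA value so it can walk defining ops (eraseLayout, squeeze, slice) to recover any dynamic slice sizes.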

jaxlib/mosaic/dialect/tpu/transforms/infer_memref_layout.cc

Lines changed: 63 additions & 7 deletions
@@ -87,6 +87,16 @@ FailureOr<TiledLayoutAttr> inferLayout(MemRefType memref_ty,
                                        int64_t leading_tile_rows = 0) {
   if (auto tiled_layout_attr =
           dyn_cast<TiledLayoutAttr>(memref_ty.getLayout())) {
+    if (leading_tile_rows > 0 && !tiled_layout_attr.getTiles().empty() &&
+        tiled_layout_attr.getTiles().front().dimensions().size() == 2 &&
+        tiled_layout_attr.getTiles().front().dimensions()[0] !=
+            leading_tile_rows) {
+      return emitError(UnknownLoc::get(memref_ty.getContext()),
+                       "Trying to infer memref layout with sublane tiling ")
+             << leading_tile_rows
+             << ", but the memref already has sublane tiling "
+             << tiled_layout_attr.getTiles().front().dimensions()[0];
+    }
     return tiled_layout_attr;
   }
   if (auto affine_map_attr = dyn_cast<AffineMapAttr>(memref_ty.getLayout())) {
@@ -226,13 +236,25 @@ LogicalResult inferOp(Operation &op, const int hardware_generation,
   if (auto alloca_op = dyn_cast<memref::AllocaOp>(op)) {
     TypedValue<MemRefType> arg = alloca_op.getResult();
     const MemRefType memref_ty = alloca_op.getResult().getType();
-    FAILUREOR_ASSIGN_OR_RETURN(const MemRefType new_memref_ty,
-                               inferMemref(memref_ty, hardware_generation,
-                                           target_shape, tpu_tiling_flags));
+    // If the memref can be reinterpreted to untiled, force to use tiling
+    // {1, target.lane_count} for 32 bit.
+    int64_t leading_tile_rows = 0;
+    // TODO(b/375038685): generalize untiled memref with packed type which
+    // needs to update load/store rules.
+    if (memref_ty.getElementTypeBitWidth() == 32 && memref_ty.getRank() > 1 &&
+        *(memref_ty.getShape().end() - 1) <= target_shape[1]) {
+      leading_tile_rows = 1;
+    }
+    FAILUREOR_ASSIGN_OR_RETURN(
+        const MemRefType new_memref_ty,
+        inferMemref(memref_ty, hardware_generation, target_shape,
+                    tpu_tiling_flags, leading_tile_rows));
     alloca_op.getResult().setType(new_memref_ty);
     if (memref_ty != new_memref_ty) {
       OpBuilder builder(alloca_op->getContext());
       builder.setInsertionPointAfter(alloca_op);
+      // TODO(b/376130272): add a canonicalizer for EraseLayoutOp so that if we
+      // have erase(erase(x)) then we rewrite it to erase(x).
       auto erase_op = builder.create<tpu::EraseLayoutOp>(
           arg.getLoc(),
           MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
@@ -296,22 +318,56 @@ LogicalResult inferFunc(func::FuncOp f, const int hardware_generation,
     }

     FAILUREOR_ASSIGN_OR_RETURN(
-        const MemRefType new_memref_ty,
+        MemRefType new_memref_ty,
         inferMemref(memref_ty, hardware_generation, target_shape,
                     tpu_tiling_flags, leading_tile_rows));
     arg.setType(new_memref_ty);
     new_arg_types.push_back(arg.getType());
     if (memref_ty != new_memref_ty) {
+      Value val = arg;
+      Operation *arg_use_op = nullptr;
+      // If the arg memref can be reinterpreted to untiled, we can insert
+      // ReinterpretCastOp to use tiling {packing, target.lane_count} before
+      // EraseLayoutOp for only the arg memrefs and expect the rest memref
+      // layout inference is based on the casted layout automatically. This
+      // would help lift many restrictions in alignment check when consuming
+      // this memref.
+      if (canReinterpretToUntiledMemref(cast<TypedValue<MemRefType>>(val),
+                                        target_shape,
+                                        /*allow_minormost_padding=*/true) &&
+          // TODO(b/375038685): generalize untiled memref with packed type
+          // which needs to update load/store rules.
+          new_memref_ty.getElementTypeBitWidth() == 32) {
+        auto tiled_layout =
+            cast<tpu::TiledLayoutAttr>(new_memref_ty.getLayout());
+        SmallVector<xla::Tile> tiles(tiled_layout.getTiles());
+        SmallVector<int64_t> new_tile_strides(tiled_layout.getTileStrides());
+        for (int i = 0; i < new_tile_strides.size() - 2; ++i) {
+          new_tile_strides[i] *= tiles[0].dimension(0);
+        }
+        tiles[0] = ::xla::Tile({1, target_shape[1]});
+        new_memref_ty = MemRefType::get(
+            new_memref_ty.getShape(), new_memref_ty.getElementType(),
+            TiledLayoutAttr::get(new_memref_ty.getContext(), tiles,
+                                 new_tile_strides),
+            new_memref_ty.getMemorySpace());
+        arg_use_op = builder.create<tpu::ReinterpretCastOp>(val.getLoc(),
+                                                            new_memref_ty, val);
+        val = arg_use_op->getResult(0);
+      }
       // Some standard MLIR ops have static checks that seems unreasonable,
       // and we know they hold in the way they are used in Mosaic. Still,
       // verification with layouts likes to fail, because it can't statically
       // prove the properties.
       auto erase_op = builder.create<tpu::EraseLayoutOp>(
-          arg.getLoc(),
+          val.getLoc(),
           MemRefType::get(new_memref_ty.getShape(), memref_ty.getElementType(),
                           /*layout=*/nullptr, new_memref_ty.getMemorySpace()),
-          arg);
-      arg.replaceAllUsesExcept(erase_op.getResult(), erase_op);
+          val);
+      if (!arg_use_op) {
+        arg_use_op = erase_op;
+      }
+      arg.replaceAllUsesExcept(erase_op.getResult(), arg_use_op);
     }
   }
   f.setFunctionType(
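To make the stride update in the ReinterpretCastOp branch concrete, here is a small Python sketch of what the loop over new_tile_strides computes (the helper name is hypothetical; strides are in units of tiles, minormost last):

# When the leading tile shrinks from (t, 128) to (1, 128), each tile covers
# 1/t as many rows, so every tile stride except the last two scales by t.
def retile_strides(tile_strides: list[int], t: int) -> list[int]:
    new_strides = list(tile_strides)
    for i in range(len(new_strides) - 2):
        new_strides[i] *= t
    return new_strides

# E.g. a (4, 16, 256) f32 memref tiled (8, 128) has tile strides (4, 2, 1):
# each leading slice spans ceil(16/8) * ceil(256/128) = 4 tiles. Retiled to
# (1, 128) it spans 16 * 2 = 32 tiles, hence strides (32, 2, 1).
assert retile_strides([4, 2, 1], 8) == [32, 2, 1]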

jaxlib/mosaic/dialect/tpu/transforms/infer_vector_layout.cc

Lines changed: 4 additions & 2 deletions
@@ -1283,7 +1283,8 @@ class VectorLayoutInferer {
                                      layout_tiling, ImplicitDim::kNone));
     } else if (bitwidth == 32 &&
                canReinterpretToUntiledMemref(
-                   src_ty, target_shape_, /*allow_minormost_padding=*/true) &&
+                   op.getBase(), target_shape_,
+                   /*allow_minormost_padding=*/true) &&
                *(src_ty.getShape().end() - 2) > 1) {
       // Since it is untiled, we can load from any arbitrary address which
       // means we can always set the sublane offset to 0.
@@ -1620,7 +1621,8 @@ class VectorLayoutInferer {
                // We accept padding in the minormost dim, because
                // apply_vector_layout will properly mask stores.
                canReinterpretToUntiledMemref(
-                   ref_ty, target_shape_, /*allow_minormost_padding=*/true)) {
+                   op.getBase(), target_shape_,
+                   /*allow_minormost_padding=*/true)) {
       // Since it is untiled, we can store to any arbitrary address which
       // means the sublane offset can be any value and we can fold it to
       // 2nd minor index.

jaxlib/mosaic/dialect/tpu/transforms/memory_space_specialization.cc

Lines changed: 4 additions & 0 deletions
@@ -70,6 +70,10 @@ LogicalResult specializeMemorySpace(TypedValue<MemRefType> value,
     to_update.pop_back();
     // Here we only have to handle the operations allowed on refs with
     // unspecified memory space.
+    if (auto op = dyn_cast<tpu::ReinterpretCastOp>(some_op)) {
+      updateResultFrom(op, op.getInput().getType());
+      continue;
+    }
     if (auto op = dyn_cast<tpu::MemRefSliceOp>(some_op)) {
      updateResultFrom(op, op.getMemRef().getType());
      continue;

jaxlib/mosaic/dialect/tpu/util.cc

Lines changed: 56 additions & 11 deletions
@@ -23,6 +23,8 @@ limitations under the License.
 #include "llvm/Support/MathExtras.h"
 #include "absl/types/span.h"
 #include "mlir/include/mlir/IR/BuiltinTypes.h"
+#include "mlir/include/mlir/IR/Value.h"
+#include "mlir/include/mlir/IR/ValueRange.h"
 #include "mlir/include/mlir/Support/LLVM.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"

@@ -69,31 +71,74 @@ std::optional<std::pair<bool, bool>> isTransposedMatmul(
   return std::pair<bool, bool>{lhs_transposed, rhs_transposed};
 }

-bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
+bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
                                    const std::array<int64_t, 2>& target_shape,
                                    bool allow_minormost_padding) {
+  MemRefType tiled_memref_ty = tiled_memref.getType();
   auto tiled_layout =
       dyn_cast<tpu::TiledLayoutAttr>(tiled_memref_ty.getLayout());
+  ValueRange dynamic_sizes = {};
+  if (!tiled_layout) {
+    if (auto erase_op = tiled_memref.getDefiningOp<tpu::EraseLayoutOp>()) {
+      tiled_memref = erase_op.getOperand();
+      tiled_memref_ty = tiled_memref.getType();
+      tiled_layout =
+          dyn_cast<tpu::TiledLayoutAttr>(tiled_memref_ty.getLayout());
+      // TODO(b/375641258): Currently we rely on the pattern `slice ->
+      // (squeeze)* -> eraseLayout` to get the dynamic sizes, but other
+      // patterns may not work here: e.g., `slice -> eraseLayout -> reshape ->
+      // eraseLayout`. We should fix this! For now, if we can not get the
+      // expected dynamic sizes, we consider the memref cannot be
+      // reinterpreted to untiled.
+      Value ref = tiled_memref;
+      while (auto squeeze_op = ref.getDefiningOp<tpu::MemRefSqueezeOp>()) {
+        ref = squeeze_op.getInput();
+      }
+      if (auto slice_op = ref.getDefiningOp<tpu::MemRefSliceOp>()) {
+        dynamic_sizes = slice_op.getDynamicSizes();
+      }
+    }
+  }
   if (!tiled_layout) {
     // We expect the tiled memref to have a tiled layout.
     return false;
   }
+  if (tiled_memref_ty.getNumDynamicDims() != dynamic_sizes.size()) {
+    return false;
+  }
   if (tiled_layout.getTiles().empty() ||
       tiled_layout.getTiles().front().dimensions().size() != 2 ||
       tiled_memref_ty.getRank() < 2) {
-    // TODO(jevinjiang): Currently we only support >= 2D memref, we might
+    // TODO(b/375642202): Currently we only support >= 2D memref, we might
     // need to handle 1D memref if we find a use case.
     return false;
   }
-  if (!allow_minormost_padding &&
-      *(tiled_memref_ty.getShape().end() - 1) != target_shape[1]) {
-    return false;
-  }
+  auto rank = tiled_memref_ty.getRank();
   auto packing = 32 / tiled_memref_ty.getElementTypeBitWidth();
-  return (*(tiled_memref_ty.getShape().end() - 1) <= target_shape[1] &&
-          *(tiled_memref_ty.getShape().end() - 2) % packing == 0 &&
-          *(tiled_layout.getTileStrides().end() - 1) == 1 &&
-          *(tiled_layout.getTileStrides().end() - 2) == 1);
+  if (tiled_memref_ty.isDynamicDim(rank - 1)) {
+    // TODO(jevinjiang): we can still allow the minormost padding if we know
+    // the max bound of the dynamic size is not larger than the
+    // target_shape[1].
+    if (!isGuaranteedDivisible(dynamic_sizes.back(), target_shape[1])) {
+      return false;
+    }
+    dynamic_sizes = dynamic_sizes.drop_back();
+  } else {
+    if (!allow_minormost_padding &&
+        tiled_memref_ty.getShape()[rank - 1] != target_shape[1]) {
+      return false;
+    }
+  }
+  if (tiled_memref_ty.isDynamicDim(rank - 2)) {
+    if (!isGuaranteedDivisible(dynamic_sizes.back(), packing)) {
+      return false;
+    }
+  } else {
+    if (tiled_memref_ty.getShape()[rank - 2] % packing != 0) {
+      return false;
+    }
+  }
+  // Check if the minormost dim has a single tile.
+  return *(tiled_layout.getTileStrides().end() - 1) == 1 &&
+         *(tiled_layout.getTileStrides().end() - 2) == 1;
 }
-
 }  // namespace mlir::tpu
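A simplified model of the shape checks above, as a non-authoritative Python sketch: dynamic dims are written as None, and divisible(dim, n) stands in for isGuaranteedDivisible on the dynamic size bound to that dim; the single-tile stride checks are omitted.

# Simplified model of the new canReinterpretToUntiledMemref shape logic.
def can_reinterpret(shape, divisible, bitwidth,
                    lane_count=128, allow_minormost_padding=False):
    if len(shape) < 2:
        return False
    packing = 32 // bitwidth
    if shape[-1] is None:
        # Dynamic minormost dim: must be provably a multiple of the lane
        # count, since the padding bound is unknown.
        if not divisible(-1, lane_count):
            return False
    elif not allow_minormost_padding and shape[-1] != lane_count:
        return False
    if shape[-2] is None:
        # Dynamic 2nd minor dim: must be provably a multiple of the packing.
        if not divisible(-2, packing):
            return False
    elif shape[-2] % packing != 0:
        return False
    return True

# E.g. an i32 ref sliced to a dynamic 2nd-minor size always passes that
# check, because packing is 32 // 32 == 1.
assert can_reinterpret((None, 128), lambda dim, n: n == 1, bitwidth=32)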

jaxlib/mosaic/dialect/tpu/util.h

Lines changed: 2 additions & 3 deletions
@@ -2,7 +2,6 @@
 #define THIRD_PARTY_PY_JAX_JAXLIB_MOSAIC_DIALECT_TPU_UTIL_H_

 #include <array>
-#include <cstddef>
 #include <cstdint>
 #include <sstream>
 #include <string>
@@ -17,7 +16,7 @@
 #include "mlir/Support/LogicalResult.h"
 #include "absl/types/span.h"
 #include "jaxlib/mosaic/dialect/tpu/tpu_dialect.h"
-#include "tsl/platform/statusor.h"
+#include "mlir/include/mlir/IR/Value.h"

 // TODO: Instead of CHECK_EQs, can we do something like TF_RET_CHECK but with
 // MLIR diagnostics?
@@ -112,7 +111,7 @@ std::optional<std::pair<bool, bool>> isTransposedMatmul(
 // considered as an untiled memref, except for potential padding in the
 // minormost dimension up to target_shape[1] (if allow_minormost_padding is
 // true).
-bool canReinterpretToUntiledMemref(MemRefType tiled_memref_ty,
+bool canReinterpretToUntiledMemref(TypedValue<MemRefType> tiled_memref,
                                    const std::array<int64_t, 2> &target_shape,
                                    bool allow_minormost_padding = false);

tests/pallas/tpu_pallas_test.py

Lines changed: 34 additions & 0 deletions
@@ -1472,6 +1472,40 @@ def kernel(index, x, y, sem):
     np.testing.assert_array_equal(y, i)
     del y

+  def test_dynamic_dma_on_2nd_minor(self):
+    def kernel(array, data, index, size, _, sem):
+      pltpu.async_copy(
+          data.at[pl.ds(0, size[0])], array.at[pl.ds(index[0], size[0])], sem
+      ).wait()
+
+    def run(array, data, index, size):
+      return pl.pallas_call(
+          kernel,
+          out_shape=array,
+          in_specs=[
+              pl.BlockSpec(memory_space=pltpu.ANY),
+              pl.BlockSpec(memory_space=pltpu.VMEM),
+              pl.BlockSpec(memory_space=pltpu.SMEM),
+              pl.BlockSpec(memory_space=pltpu.SMEM),
+          ],
+          scratch_shapes=[
+              pltpu.SemaphoreType.DMA,
+          ],
+          out_specs=pl.BlockSpec(memory_space=pltpu.ANY),
+          input_output_aliases={0: 0},
+      )(array, data, index, size)
+
+    array = jnp.zeros((1024, 128), jnp.int32)
+    data = jnp.ones((8, 128), jnp.int32)
+    index = jnp.array([3], jnp.int32)
+    size = jnp.array([5], jnp.int32)
+
+    expected = array.at[index[0] : index[0] + size[0]].set(
+        data[index[0] : index[0] + size[0]]
+    )
+    result = run(array, data, index, size)
+    np.testing.assert_array_equal(result, expected)
+

 class PallasCallDMAInterpretTest(PallasCallDMATest):
   INTERPRET = True
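A note on the test's constants: index = 3 starts the copy at a sublane offset that is not aligned to the (8, 128) native tile, and size = 5 makes the extent on the 2nd minor dynamic, which is exactly the case this change enables; with int32 data the packing is 32 / 32 = 1, so the dynamic size trivially satisfies the divisibility requirement described in util.cc above.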
