Skip to content

Commit a03b09e

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Mosaic TPU] Propagate tpu.erase_layout through more TPU ops
PiperOrigin-RevId: 834765662
1 parent e257fed commit a03b09e

File tree

2 files changed

+89
-1
lines changed

2 files changed

+89
-1
lines changed

jaxlib/mosaic/dialect/tpu/tpu.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ def TPU_VectorStoreOp :TPU_Op<"vector_store", [DefaultMemWrite, AttrSizedOperand
299299
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `,` type($valueToStore) `,` type($mask)
300300
}];
301301
let hasVerifier = 1;
302+
let hasCanonicalizeMethod = 1;
302303
}
303304

304305
// tpu.vector_load loads a vector from memory into a register.
@@ -326,6 +327,7 @@ def TPU_VectorLoadOp :TPU_Op<"vector_load", [DefaultMemRead, AttrSizedOperandSeg
326327
$base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `,` type($result) `,` type($mask)
327328
}];
328329
let hasVerifier = 1;
330+
let hasCanonicalizeMethod = 1;
329331
}
330332

331333
def TPU_StridedLoadOp : TPU_Op<"strided_load", [DefaultMemRead]> {
@@ -413,6 +415,7 @@ def TPU_VectorLoadIdxOp :TPU_Op<"vector_load_idx", [DefaultMemRead, AttrSizedOpe
413415
$base `[` $indices `]` (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($value) `,` type($mask)
414416
}];
415417
let hasVerifier = 1;
418+
let hasCanonicalizeMethod = 1;
416419
}
417420

418421
// tpu.vector_store_idx stores values to arbitrary locations in memory.
@@ -448,6 +451,7 @@ def TPU_VectorStoreIdxOp :TPU_Op<"vector_store_idx", [DefaultMemWrite, AttrSized
448451
$base `[` $indices `]` `,` $valueToStore (`masked` $mask^)? attr-dict `:` type($base) `[` type($indices) `]` `,` type($valueToStore) `,` type($mask)
449452
}];
450453
let hasVerifier = 1;
454+
let hasCanonicalizeMethod = 1;
451455
}
452456

453457
// TODO(jevinjiang): deprecate to use dynamic_rotate.
@@ -1079,6 +1083,7 @@ def TPU_ReinterpretCastOp : TPU_Op<"reinterpret_cast", [Pure]> {
10791083
$input attr-dict `:` type($input) `->` type($result)
10801084
}];
10811085
let hasVerifier = 1;
1086+
let hasCanonicalizeMethod = 1;
10821087
}
10831088

10841089
def TPU_AssumeLayoutOp : TPU_Op<"assume_layout", [Pure]> {
@@ -1099,6 +1104,7 @@ def TPU_EraseLayoutOp : TPU_Op<"erase_memref_layout", [Pure, InferTypeOpAdaptor]
10991104
let assemblyFormat = [{
11001105
$operand attr-dict `:` type($operand) `->` type($result)
11011106
}];
1107+
let hasFolder = 1;
11021108
}
11031109

11041110
// Returns the ID of the current device.
@@ -1206,6 +1212,7 @@ def TPU_EnqueueDMAOp : TPU_Op<"enqueue_dma", [AttrSizedOperandSegments]> {
12061212
attr-dict
12071213
}];
12081214
let hasVerifier = 1;
1215+
let hasCanonicalizeMethod = 1;
12091216
}
12101217

12111218
// A base class for all ops that need to differentiate between gather and
@@ -1260,6 +1267,7 @@ def TPU_EnqueueIndirectDMAOp : TPU_Op<"enqueue_indirect_dma">, IndirectDMAOp {
12601267
ArrayRef<int64_t> offsets_shape,
12611268
MemRefType operand_ty);
12621269
}];
1270+
let hasCanonicalizeMethod = 1;
12631271
}
12641272

12651273
// tpu.wait_dma2 waits for a DMA to complete.
@@ -1288,6 +1296,7 @@ def TPU_WaitDMA2Op : TPU_Op<"wait_dma2", [AttrSizedOperandSegments]> {
12881296
let builders = [
12891297
OpBuilder<(ins "Value":$semaphore, "Value":$src, "Value":$dst)>
12901298
];
1299+
let hasCanonicalizeMethod = 1;
12911300
}
12921301

12931302
// TODO(b/395630795): Remove after 2025-08-10.
@@ -1318,6 +1327,7 @@ def TPU_WaitIndirectDMAOp : TPU_Op<"wait_indirect_dma">, IndirectDMAOp {
13181327
attr-dict
13191328
}];
13201329
let hasVerifier = 1;
1330+
let hasCanonicalizeMethod = 1;
13211331
let extraClassDeclaration = extraBaseClassDeclaration;
13221332
}
13231333

jaxlib/mosaic/dialect/tpu/tpu_ops.cc

Lines changed: 79 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ limitations under the License.
4242
#include "mlir/IR/Diagnostics.h"
4343
#include "mlir/IR/IRMapping.h"
4444
#include "mlir/IR/Matchers.h"
45+
#include "mlir/IR/OpDefinition.h"
4546
#include "mlir/IR/OperationSupport.h"
4647
#include "mlir/IR/PatternMatch.h"
4748
#include "mlir/IR/Region.h"
@@ -58,6 +59,22 @@ namespace tpu {
5859

5960
namespace {
6061

62+
// This should only be used to canonicalize away EraseLayoutOps that feed ops
63+
// that only consume memrefs and don't return them.
64+
LogicalResult propagateTiledLayoutToConsumer(Operation* op,
65+
PatternRewriter& rewriter) {
66+
bool modified = false;
67+
for (unsigned int i = 0; i < op->getNumOperands(); ++i) {
68+
if (auto erase_layout_op =
69+
op->getOperand(i).getDefiningOp<tpu::EraseLayoutOp>()) {
70+
modified = true;
71+
rewriter.modifyOpInPlace(
72+
op, [&]() { op->setOperand(i, erase_layout_op.getOperand()); });
73+
}
74+
}
75+
return success(modified);
76+
}
77+
6178
llvm::RoundingMode convertTpuRoundingModeToLLVMIR(tpu::RoundingMode mode) {
6279
switch (mode) {
6380
case tpu::RoundingMode::kToNearestEven:
@@ -268,6 +285,8 @@ struct MemRefSliceFoldConstantDynamicDim
268285
op.getResult().setType(new_type);
269286
op.getDynamicSizesMutable().assign(new_dynamic_sizes);
270287
});
288+
mlir::OpBuilder::InsertionGuard guard(rewriter);
289+
rewriter.setInsertionPointAfter(op);
271290
auto cast_op = memref::CastOp::create(rewriter, op.getLoc(), old_type, op);
272291
rewriter.replaceAllUsesExcept(op, cast_op, cast_op);
273292
return success();
@@ -604,7 +623,7 @@ LogicalResult MemRefReshapeOp::canonicalize(MemRefReshapeOp op,
604623
}
605624
auto layout_ref = erase_layout_op.getOperand();
606625
auto layout_ty = layout_ref.getType();
607-
auto layout = dyn_cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
626+
auto layout = cast<tpu::TiledLayoutAttr>(layout_ty.getLayout());
608627
CHECK(!layout.getTiles().empty());
609628
auto tile = layout.getTiles().front().dimensions();
610629
auto new_tile_strides = ComputeTileStrides(dst_ty, tile);
@@ -788,6 +807,11 @@ LogicalResult VectorStoreOp::verify() {
788807
return verifyStoreOp(*this);
789808
}
790809

810+
LogicalResult VectorStoreOp::canonicalize(VectorStoreOp op,
811+
PatternRewriter& rewriter) {
812+
return propagateTiledLayoutToConsumer(op, rewriter);
813+
}
814+
791815
template <typename Op>
792816
LogicalResult verifyLoadOp(Op op) {
793817
MemRefType ref_ty = op.getBase().getType();
@@ -826,6 +850,11 @@ LogicalResult VectorLoadOp::verify() {
826850
return verifyLoadOp(*this);
827851
}
828852

853+
LogicalResult VectorLoadOp::canonicalize(VectorLoadOp op,
854+
PatternRewriter& rewriter) {
855+
return propagateTiledLayoutToConsumer(op, rewriter);
856+
}
857+
829858
LogicalResult VectorLoadIdxOp::verify() {
830859
VectorType value_ty = getResult().getType();
831860
MemRefType ref_ty = getBase().getType();
@@ -846,6 +875,11 @@ LogicalResult VectorLoadIdxOp::verify() {
846875
return verifyLoadOp(*this);
847876
}
848877

878+
LogicalResult VectorLoadIdxOp::canonicalize(VectorLoadIdxOp op,
879+
PatternRewriter& rewriter) {
880+
return propagateTiledLayoutToConsumer(op, rewriter);
881+
}
882+
849883
LogicalResult VectorStoreIdxOp::verify() {
850884
VectorType value_ty = getValueToStore().getType();
851885
MemRefType ref_ty = getBase().getType();
@@ -870,6 +904,11 @@ LogicalResult VectorStoreIdxOp::verify() {
870904
return verifyStoreOp(*this);
871905
}
872906

907+
LogicalResult VectorStoreIdxOp::canonicalize(VectorStoreIdxOp op,
908+
PatternRewriter& rewriter) {
909+
return propagateTiledLayoutToConsumer(op, rewriter);
910+
}
911+
873912
LogicalResult ReinterpretCastOp::verify() {
874913
auto source_type = getMemRefType(getInput());
875914
auto target_type = getType();
@@ -881,6 +920,17 @@ LogicalResult ReinterpretCastOp::verify() {
881920
return success();
882921
}
883922

923+
LogicalResult ReinterpretCastOp::canonicalize(ReinterpretCastOp op,
924+
PatternRewriter& rewriter) {
925+
if (auto erase_layout_op = op.getInput().getDefiningOp<EraseLayoutOp>()) {
926+
rewriter.modifyOpInPlace(op, [&]() {
927+
op.getInputMutable().assign(erase_layout_op.getOperand());
928+
});
929+
return success();
930+
}
931+
return failure();
932+
}
933+
884934
LogicalResult EraseLayoutOp::inferReturnTypes(
885935
MLIRContext* context, std::optional<Location> location,
886936
EraseLayoutOp::Adaptor adaptor,
@@ -891,6 +941,14 @@ LogicalResult EraseLayoutOp::inferReturnTypes(
891941
return success();
892942
}
893943

944+
OpFoldResult EraseLayoutOp::fold(FoldAdaptor op) {
945+
// If the operand has no interesting layout then there's no need to erase it.
946+
if (getOperand().getType().getLayout().isIdentity()) {
947+
return op.getOperand();
948+
}
949+
return OpFoldResult();
950+
}
951+
894952
template <typename Op>
895953
LogicalResult verifyRotateOp(Op op) {
896954
auto vty = op.getResult().getType();
@@ -1371,6 +1429,11 @@ LogicalResult EnqueueDMAOp::verify() {
13711429
return success();
13721430
}
13731431

1432+
LogicalResult EnqueueDMAOp::canonicalize(EnqueueDMAOp op,
1433+
PatternRewriter& rewriter) {
1434+
return propagateTiledLayoutToConsumer(op, rewriter);
1435+
}
1436+
13741437
LogicalResult EnqueueIndirectDMAOp::verifyGather(
13751438
MemRefType operand_ty, ArrayRef<int64_t> offsets_shape,
13761439
MemRefType result_ty) {
@@ -1550,6 +1613,11 @@ LogicalResult EnqueueIndirectDMAOp::verify() {
15501613
/*operand_ty=*/target_ty);
15511614
}
15521615

1616+
LogicalResult EnqueueIndirectDMAOp::canonicalize(EnqueueIndirectDMAOp op,
1617+
PatternRewriter& rewriter) {
1618+
return propagateTiledLayoutToConsumer(op, rewriter);
1619+
}
1620+
15531621
// TODO(b/395630795): Remove after 2025-08-10.
15541622
LogicalResult WaitDMAOp::verify() {
15551623
auto sem_type = getMemRefType(getSemaphore());
@@ -1573,6 +1641,11 @@ LogicalResult WaitDMA2Op::verify() {
15731641
return success();
15741642
}
15751643

1644+
LogicalResult WaitDMA2Op::canonicalize(WaitDMA2Op op,
1645+
PatternRewriter& rewriter) {
1646+
return propagateTiledLayoutToConsumer(op, rewriter);
1647+
}
1648+
15761649
FailureOr<bool> WaitIndirectDMAOp::isGather() {
15771650
return mlir::tpu::isGather(*getOperation(), getSrc(), getDst());
15781651
}
@@ -1593,6 +1666,11 @@ LogicalResult WaitIndirectDMAOp::verify() {
15931666
return isGather();
15941667
}
15951668

1669+
LogicalResult WaitIndirectDMAOp::canonicalize(WaitIndirectDMAOp op,
1670+
PatternRewriter& rewriter) {
1671+
return propagateTiledLayoutToConsumer(op, rewriter);
1672+
}
1673+
15961674
LogicalResult RegionOp::verify() {
15971675
for (auto result_type : getResultTypes()) {
15981676
if (!isa<FloatType, IntegerType, VectorType, IndexType>(result_type)) {

0 commit comments

Comments
 (0)