Skip to content

Commit df4c6ce

Browse files
authored
[DT][NFC] Refactor linalg.fill/generic op lowering to interface implementation. (#22343)
The revision deletes two patterns and moves the implementation to interface methods: - MaterializeDPSOperation<linalg::FillOp> - MaterializeDPSOperation<linalg::GenericOp> The lowerGenericOpWithResolvedLayouts method is very similar to the original lowerGenericOpWithEncoding. The new method takes the same input arguments like `lowerOp` method with an additional `LayoutMaterializerAttr` attribute, because it needs the packing info to generate indexing maps. In the new implementation, it uses `getEncodingInfoFromLayout` method directly while the method in the type converter is just a wrapper. I.e., it reduces the dependency from the type converter. See https://gist.github.com/hanhanW/e89aa9e5052f8db37b2543fa368ed605 for the diff. It is a step towards #20160, as the padding resolver no longer rely on "fallback" path to drop the encodings. It has its own implementation that uses `clone`. --------- Signed-off-by: hanhanW <[email protected]>
1 parent b9a06f6 commit df4c6ce

File tree

7 files changed

+363
-379
lines changed

7 files changed

+363
-379
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

Lines changed: 7 additions & 356 deletions
Large diffs are not rendered by default.

compiler/src/iree/compiler/Codegen/ExternalInterfaces/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ iree_compiler_cc_library(
4444
"@llvm-project//mlir:DialectUtils",
4545
"@llvm-project//mlir:IR",
4646
"@llvm-project//mlir:LinalgDialect",
47+
"@llvm-project//mlir:LinalgTransforms",
4748
"@llvm-project//mlir:TensorDialect",
4849
],
4950
)

compiler/src/iree/compiler/Codegen/ExternalInterfaces/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ iree_cc_library(
3030
MLIRArithDialect
3131
MLIRIR
3232
MLIRLinalgDialect
33+
MLIRLinalgTransforms
3334
MLIRTensorDialect
3435
iree::compiler::Codegen::Dialect::CPU::IR::IREECPUDialect
3536
iree::compiler::Codegen::Dialect::Codegen::IR::IREECodegenDialect

compiler/src/iree/compiler/Codegen/ExternalInterfaces/CPUEncodingExternalModels.cpp

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,6 @@
4343
#include "iree/compiler/Codegen/ExternalInterfaces/Utils.h"
4444
#include "iree/compiler/Codegen/Utils/CPUUtils.h"
4545
#include "iree/compiler/Codegen/Utils/Utils.h"
46-
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
4746
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
4847
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
4948
#include "llvm/Support/DebugLog.h"
@@ -284,11 +283,11 @@ TileMxNxK chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
284283
return bestRatedTile;
285284
}
286285

287-
FailureOr<Operation *> lowerContractionOpWithEncoding(
286+
Operation *lowerContractionOpWithEncoding(
288287
OpBuilder &builder, linalg::LinalgOp linalgOp, ValueRange operands,
289288
IREE::Encoding::LayoutMaterializerAttr layoutAttr) {
290289
if (!linalgOp.hasPureTensorSemantics()) {
291-
return failure();
290+
return nullptr;
292291
}
293292

294293
auto inputs = linalgOp.getDpsInputOperands();
@@ -301,14 +300,14 @@ FailureOr<Operation *> lowerContractionOpWithEncoding(
301300
auto rhsEncoding = IREE::Encoding::getEncodingAttr(rhsType);
302301
auto resultEncoding = IREE::Encoding::getEncodingAttr(resultType);
303302
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
304-
return failure();
303+
return nullptr;
305304
}
306305

307306
if (lhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_LHS ||
308307
rhsEncoding.getOperandIndex().getValue() != IREE::Encoding::MATMUL_RHS ||
309308
resultEncoding.getOperandIndex().getValue() !=
310309
IREE::Encoding::MATMUL_RESULT) {
311-
return failure();
310+
return nullptr;
312311
}
313312

314313
MaterializeEncodingInfo encodingInfo = {};
@@ -725,11 +724,21 @@ struct CPUEncodingResolverMaterializerAttr final
725724
if (!linalgOp) {
726725
return nullptr;
727726
}
728-
729-
FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
730-
b, linalgOp, convertedOperands,
731-
cast<IREE::Encoding::LayoutMaterializerAttr>(layoutAttr));
732-
return newOp.value_or(nullptr);
727+
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
728+
return lowerFillOpWithResolvedLayouts(b, fillOp, convertedResTypes,
729+
convertedOperands);
730+
}
731+
if (linalg::isaContractionOpInterface(linalgOp)) {
732+
return lowerContractionOpWithEncoding(
733+
b, linalgOp, convertedOperands,
734+
cast<IREE::Encoding::LayoutMaterializerAttr>(layoutAttr));
735+
}
736+
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
737+
return lowerGenericOpWithResolvedLayouts(
738+
b, genericOp, convertedResTypes, convertedOperands,
739+
cast<IREE::Encoding::LayoutMaterializerAttr>(attr));
740+
}
741+
return nullptr;
733742
}
734743
};
735744

@@ -869,11 +878,21 @@ struct VMVXEncodingResolverMaterializerAttr final
869878
if (!linalgOp) {
870879
return nullptr;
871880
}
872-
873-
FailureOr<Operation *> newOp = lowerContractionOpWithEncoding(
874-
b, linalgOp, convertedOperands,
875-
cast<IREE::Encoding::LayoutMaterializerAttr>(layoutAttr));
876-
return newOp.value_or(nullptr);
881+
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
882+
return lowerFillOpWithResolvedLayouts(b, fillOp, convertedResTypes,
883+
convertedOperands);
884+
}
885+
if (linalg::isaContractionOpInterface(linalgOp)) {
886+
return lowerContractionOpWithEncoding(
887+
b, linalgOp, convertedOperands,
888+
cast<IREE::Encoding::LayoutMaterializerAttr>(layoutAttr));
889+
}
890+
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
891+
return lowerGenericOpWithResolvedLayouts(
892+
b, genericOp, convertedResTypes, convertedOperands,
893+
cast<IREE::Encoding::LayoutMaterializerAttr>(attr));
894+
}
895+
return nullptr;
877896
}
878897
};
879898

compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,9 @@
3333
#include "iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.h"
3434
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
3535
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
36-
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
3736
#include "iree/compiler/Codegen/Dialect/GPU/TargetUtils/KnownTargets.h"
3837
#include "iree/compiler/Codegen/ExternalInterfaces/Utils.h"
3938
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
40-
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
4139
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"
4240
#include "iree/compiler/Dialect/Encoding/Utils/Utils.h"
4341
#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
@@ -335,10 +333,6 @@ static Operation *lowerContractionOrScaledContractionOpToInnerTiledOp(
335333
if (!linalgOp.hasPureTensorSemantics()) {
336334
return nullptr;
337335
}
338-
if (!linalg::isaContractionOpInterface(linalgOp) &&
339-
!IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
340-
return nullptr;
341-
}
342336

343337
SmallVector<Value> inputs = linalgOp.getDpsInputs();
344338
SmallVector<Value> outputs = linalgOp.getDpsInits();
@@ -517,8 +511,21 @@ struct GPUEncodingResolverMaterializerAttr
517511
if (!linalgOp) {
518512
return nullptr;
519513
}
520-
return lowerContractionOrScaledContractionOpToInnerTiledOp(
521-
b, linalgOp, convertedOperands, resolverAttr);
514+
if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
515+
return lowerFillOpWithResolvedLayouts(b, fillOp, convertedResTypes,
516+
convertedOperands);
517+
}
518+
if (linalg::isaContractionOpInterface(linalgOp) ||
519+
IREE::LinalgExt::isaScaledContractionOpInterface(linalgOp)) {
520+
return lowerContractionOrScaledContractionOpToInnerTiledOp(
521+
b, linalgOp, convertedOperands, resolverAttr);
522+
}
523+
if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
524+
return lowerGenericOpWithResolvedLayouts(
525+
b, genericOp, convertedResTypes, convertedOperands,
526+
cast<IREE::Encoding::LayoutMaterializerAttr>(attr));
527+
}
528+
return nullptr;
522529
}
523530
};
524531

0 commit comments

Comments
 (0)