Skip to content

Commit 59ce62f

Browse files
authored
[DT][NFC] Collapse MaterializeScaledContractionOp into generic pattern. (#22340)
The op lowering is driven by interface methods -- `lowerOp`, so we don't need a separate pattern for scaled matmul cases. https://github.com/iree-org/iree/blob/d4d74cb2f1f71fe8b3aaf7a41e5d5895cb726413/compiler/src/iree/compiler/Codegen/ExternalInterfaces/GPUEncodingExternalModels.cpp#L509-L523 It is a step towards #20160 Signed-off-by: hanhanW <[email protected]>
1 parent 189dc86 commit 59ce62f

File tree

1 file changed

+15
-49
lines changed

1 file changed

+15
-49
lines changed

compiler/src/iree/compiler/Codegen/Common/MaterializeEncodingPatterns.cpp

Lines changed: 15 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
#include "iree/compiler/Codegen/Utils/EncodingUtils.h"
1414
#include "iree/compiler/Codegen/Utils/Utils.h"
1515
#include "iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
16-
#include "iree/compiler/Dialect/HAL/IR/HALTypes.h"
1716
#include "iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
1817
#include "iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
1918
#include "iree/compiler/Dialect/Util/IR/UtilOps.h"
@@ -858,9 +857,13 @@ class MaterializeContractionOp
858857
LogicalResult
859858
matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
860859
ConversionPatternRewriter &rewriter) const override {
861-
if (!linalg::isaContractionOpInterface(op)) {
860+
// TODO(hanchung): Remove the check after moving other ops, e.g., fill,
861+
// generic, etc, lowering patterns to interface implementation.
862+
if (!linalg::isaContractionOpInterface(op) &&
863+
!IREE::LinalgExt::isaScaledContractionOpInterface(op)) {
862864
return rewriter.notifyMatchFailure(
863-
op, "does not implement ContractionOpInterface");
865+
op, "does not match linalg::isaContractionOpInterface and "
866+
"LinalgExt::isaScaledContractionOpInterface");
864867
}
865868

866869
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
@@ -879,43 +882,6 @@ class MaterializeContractionOp
879882
}
880883
};
881884

882-
/// Pattern to convert scaled contraction operations.
883-
class MaterializeScaledContractionOp
884-
: public OpInterfaceConversionPattern<linalg::LinalgOp> {
885-
public:
886-
MaterializeScaledContractionOp(
887-
const MaterializeEncodingTypeConverter &typeConverter,
888-
MLIRContext *context, PatternBenefit benefit = 1)
889-
: OpInterfaceConversionPattern<linalg::LinalgOp>(typeConverter, context,
890-
benefit) {}
891-
892-
LogicalResult
893-
matchAndRewrite(linalg::LinalgOp op, ArrayRef<Value> operands,
894-
ConversionPatternRewriter &rewriter) const override {
895-
if (!IREE::LinalgExt::isaScaledContractionOpInterface(op)) {
896-
return rewriter.notifyMatchFailure(
897-
op, "does not implement ScaledContractionOpInterface");
898-
}
899-
900-
auto converter = static_cast<const MaterializeEncodingTypeConverter *>(
901-
this->getTypeConverter());
902-
903-
IREE::Encoding::LayoutMaterializerAttr layoutAttr =
904-
converter->getLayoutAttr();
905-
SmallVector<Type> convertedResTypes;
906-
for (Value init : op.getDpsInits()) {
907-
convertedResTypes.push_back(converter->convertType(init.getType()));
908-
}
909-
Operation *newOp =
910-
layoutAttr.lowerOp(rewriter, op, convertedResTypes, operands);
911-
if (!newOp) {
912-
return failure();
913-
}
914-
rewriter.replaceOp(op, newOp->getResults());
915-
return success();
916-
}
917-
};
918-
919885
static bool isRankedTensorTypeWithEncoding(Type type) {
920886
auto rankedTensorType = dyn_cast<RankedTensorType>(type);
921887
if (!rankedTensorType) {
@@ -975,15 +941,15 @@ void populateMaterializeEncodingPatterns(
975941
isRankedTensorTypeWithEncoding);
976942
});
977943

978-
patterns.insert<
979-
MaterializeContractionOp, MaterializeScaledContractionOp,
980-
SetEncodingOpLoweringConversion, UnsetEncodingOpLoweringConversion,
981-
MaterializeDPSOperation<linalg::FillOp>,
982-
MaterializeDPSOperation<linalg::GenericOp>,
983-
MaterializeOperation<tensor::EmptyOp>, MaterializeOptimizationBarrierOp,
984-
MaterializeTensorExtDispatchTensorLoadOp,
985-
MaterializeTensorExtDispatchTensorStoreOp,
986-
MaterializeInterfaceBindingEncoding, MaterializeFuncReturnOp>(
944+
patterns.insert<MaterializeContractionOp, SetEncodingOpLoweringConversion,
945+
UnsetEncodingOpLoweringConversion,
946+
MaterializeDPSOperation<linalg::FillOp>,
947+
MaterializeDPSOperation<linalg::GenericOp>,
948+
MaterializeOperation<tensor::EmptyOp>,
949+
MaterializeOptimizationBarrierOp,
950+
MaterializeTensorExtDispatchTensorLoadOp,
951+
MaterializeTensorExtDispatchTensorStoreOp,
952+
MaterializeInterfaceBindingEncoding, MaterializeFuncReturnOp>(
987953
typeConverter, context);
988954
};
989955

0 commit comments

Comments
 (0)