1313#include " iree/compiler/Codegen/Utils/EncodingUtils.h"
1414#include " iree/compiler/Codegen/Utils/Utils.h"
1515#include " iree/compiler/Dialect/Encoding/IR/EncodingOps.h"
16- #include " iree/compiler/Dialect/HAL/IR/HALTypes.h"
1716#include " iree/compiler/Dialect/LinalgExt/Utils/MatchUtils.h"
1817#include " iree/compiler/Dialect/TensorExt/IR/TensorExtOps.h"
1918#include " iree/compiler/Dialect/Util/IR/UtilOps.h"
@@ -858,9 +857,13 @@ class MaterializeContractionOp
858857 LogicalResult
859858 matchAndRewrite (linalg::LinalgOp op, ArrayRef<Value> operands,
860859 ConversionPatternRewriter &rewriter) const override {
861- if (!linalg::isaContractionOpInterface (op)) {
860+ // TODO(hanchung): Remove the check after moving other ops, e.g., fill,
861+ // generic, etc, lowering patterns to interface implementation.
862+ if (!linalg::isaContractionOpInterface (op) &&
863+ !IREE::LinalgExt::isaScaledContractionOpInterface (op)) {
862864 return rewriter.notifyMatchFailure (
863- op, " does not implement ContractionOpInterface" );
865+ op, " does not match linalg::isaContractionOpInterface and "
866+ " LinalgExt::isaScaledContractionOpInterface" );
864867 }
865868
866869 auto converter = static_cast <const MaterializeEncodingTypeConverter *>(
@@ -879,43 +882,6 @@ class MaterializeContractionOp
879882 }
880883};
881884
882- // / Pattern to convert scaled contraction operations.
883- class MaterializeScaledContractionOp
884- : public OpInterfaceConversionPattern<linalg::LinalgOp> {
885- public:
886- MaterializeScaledContractionOp (
887- const MaterializeEncodingTypeConverter &typeConverter,
888- MLIRContext *context, PatternBenefit benefit = 1 )
889- : OpInterfaceConversionPattern<linalg::LinalgOp>(typeConverter, context,
890- benefit) {}
891-
892- LogicalResult
893- matchAndRewrite (linalg::LinalgOp op, ArrayRef<Value> operands,
894- ConversionPatternRewriter &rewriter) const override {
895- if (!IREE::LinalgExt::isaScaledContractionOpInterface (op)) {
896- return rewriter.notifyMatchFailure (
897- op, " does not implement ScaledContractionOpInterface" );
898- }
899-
900- auto converter = static_cast <const MaterializeEncodingTypeConverter *>(
901- this ->getTypeConverter ());
902-
903- IREE::Encoding::LayoutMaterializerAttr layoutAttr =
904- converter->getLayoutAttr ();
905- SmallVector<Type> convertedResTypes;
906- for (Value init : op.getDpsInits ()) {
907- convertedResTypes.push_back (converter->convertType (init.getType ()));
908- }
909- Operation *newOp =
910- layoutAttr.lowerOp (rewriter, op, convertedResTypes, operands);
911- if (!newOp) {
912- return failure ();
913- }
914- rewriter.replaceOp (op, newOp->getResults ());
915- return success ();
916- }
917- };
918-
919885static bool isRankedTensorTypeWithEncoding (Type type) {
920886 auto rankedTensorType = dyn_cast<RankedTensorType>(type);
921887 if (!rankedTensorType) {
@@ -975,15 +941,15 @@ void populateMaterializeEncodingPatterns(
975941 isRankedTensorTypeWithEncoding);
976942 });
977943
978- patterns.insert <
979- MaterializeContractionOp, MaterializeScaledContractionOp ,
980- SetEncodingOpLoweringConversion, UnsetEncodingOpLoweringConversion ,
981- MaterializeDPSOperation<linalg::FillOp >,
982- MaterializeDPSOperation<linalg::GenericOp >,
983- MaterializeOperation<tensor::EmptyOp>, MaterializeOptimizationBarrierOp,
984- MaterializeTensorExtDispatchTensorLoadOp,
985- MaterializeTensorExtDispatchTensorStoreOp,
986- MaterializeInterfaceBindingEncoding, MaterializeFuncReturnOp>(
944+ patterns.insert <MaterializeContractionOp, SetEncodingOpLoweringConversion,
945+ UnsetEncodingOpLoweringConversion ,
946+ MaterializeDPSOperation<linalg::FillOp> ,
947+ MaterializeDPSOperation<linalg::GenericOp >,
948+ MaterializeOperation<tensor::EmptyOp >,
949+ MaterializeOptimizationBarrierOp,
950+ MaterializeTensorExtDispatchTensorLoadOp,
951+ MaterializeTensorExtDispatchTensorStoreOp,
952+ MaterializeInterfaceBindingEncoding, MaterializeFuncReturnOp>(
987953 typeConverter, context);
988954};
989955
0 commit comments