diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 9b29179f36871..c4c0497c2d1f0 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final } }; -template -struct VectorReductionPattern final - : public OpConversionPattern { +template +struct VectorReductionPatternBase : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const final { Type resultType = typeConverter->convertType(reduceOp.getType()); if (!resultType) return failure(); @@ -368,9 +367,22 @@ struct VectorReductionPattern final if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); - // Extract all elements. + SmallVector extractedElements = + extractAllElements(reduceOp, adaptor, srcVectorType, rewriter); + + const auto &self = static_cast(*this); + + return self.reduceExtracted(reduceOp, extractedElements, resultType, + rewriter); + } + +private: + SmallVector + extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor, + VectorType srcVectorType, + ConversionPatternRewriter &rewriter) const { int numElements = srcVectorType.getDimSize(0); - SmallVector values; + SmallVector values; values.reserve(numElements + (adaptor.getAcc() != nullptr)); Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { @@ -381,9 +393,26 @@ struct VectorReductionPattern final if (Value acc = adaptor.getAcc()) values.push_back(acc); - // Reduce them. - Value result = values.front(); - for (Value next : llvm::ArrayRef(values).drop_front()) { + return values; + } +}; + +#define VECTOR_REDUCTION_BASE \ + VectorReductionPatternBase> +template +struct VectorReductionPattern final : VECTOR_REDUCTION_BASE { + using Base = VECTOR_REDUCTION_BASE; + using Base::Base; + + LogicalResult reduceExtracted(vector::ReductionOp reduceOp, + ArrayRef extractedElements, + Type resultType, + ConversionPatternRewriter &rewriter) const { + mlir::Location loc = reduceOp->getLoc(); + Value result = extractedElements.front(); + for (Value next : llvm::ArrayRef(extractedElements).drop_front()) { switch (reduceOp.getKind()) { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ @@ -403,10 +432,6 @@ struct VectorReductionPattern final INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); - INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); - INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); - INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp); - INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp); INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); @@ -416,6 +441,8 @@ struct VectorReductionPattern final case vector::CombiningKind::OR: case vector::CombiningKind::XOR: return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); + default: + return rewriter.notifyMatchFailure(reduceOp, "not handled here"); } } @@ -423,6 +450,48 @@ struct VectorReductionPattern final return success(); } }; +#undef VECTOR_REDUCTION_BASE +#undef INT_AND_FLOAT_CASE +#undef INT_OR_FLOAT_CASE + +#define MIN_MAX_PATTERN_BASE \ + VectorReductionPatternBase< \ + VectorReductionFloatMinMax> +template +struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE { + using Base = MIN_MAX_PATTERN_BASE; + using Base::Base; + + LogicalResult reduceExtracted(vector::ReductionOp reduceOp, + ArrayRef extractedElements, + Type resultType, + ConversionPatternRewriter &rewriter) const { + mlir::Location loc = reduceOp->getLoc(); + Value result = extractedElements.front(); + for (Value next : llvm::ArrayRef(extractedElements).drop_front()) { + switch (reduceOp.getKind()) { + +#define INT_OR_FLOAT_CASE(kind, fop) \ + case vector::CombiningKind::kind: \ + result = rewriter.create(loc, resultType, result, next); \ + break + + INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); + INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); + INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp); + INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp); + + default: + return rewriter.notifyMatchFailure(reduceOp, "not handled here"); + } + } + + rewriter.replaceOp(reduceOp, result); + return success(); + } +}; +#undef MIN_MAX_PATTERN_BASE +#undef INT_OR_FLOAT_CASE class VectorSplatPattern final : public OpConversionPattern { public: @@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern { }; } // namespace -#define CL_MAX_MIN_OPS \ - spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \ - spirv::CLSMaxOp, spirv::CLSMinOp +#define CL_INT_MAX_MIN_OPS \ + spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp + +#define GL_INT_MAX_MIN_OPS \ + spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp -#define GL_MAX_MIN_OPS \ - spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \ - spirv::GLSMaxOp, spirv::GLSMinOp +#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp +#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - VectorFmaOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern, - VectorReductionPattern, VectorShapeCast, - VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern>(typeConverter, patterns.getContext()); + patterns.add< + VectorBitcastConvert, VectorBroadcastConvert, + VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, + VectorFmaOpConvert, VectorInsertElementOpConvert, + VectorInsertOpConvert, VectorReductionPattern, + VectorReductionPattern, + VectorReductionFloatMinMax, + VectorReductionFloatMinMax, VectorShapeCast, + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, + VectorSplatPattern>(typeConverter, patterns.getContext()); } void mlir::populateVectorReductionToSPIRVDotProductPatterns(