2020#include " mlir/Dialect/Vector/IR/VectorOps.h"
2121#include " mlir/IR/BuiltinAttributes.h"
2222#include " mlir/IR/BuiltinTypes.h"
23+ #include " mlir/IR/Location.h"
2324#include " mlir/IR/Matchers.h"
2425#include " mlir/IR/PatternMatch.h"
2526#include " mlir/IR/TypeUtilities.h"
@@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final
351352 }
352353};
353354
354- template <class SPIRVFMaxOp , class SPIRVFMinOp , class SPIRVUMaxOp ,
355- class SPIRVUMinOp , class SPIRVSMaxOp , class SPIRVSMinOp >
356- struct VectorReductionPattern final
357- : public OpConversionPattern<vector::ReductionOp> {
355+ template <typename Derived>
356+ struct VectorReductionPatternBase : OpConversionPattern<vector::ReductionOp> {
358357 using OpConversionPattern::OpConversionPattern;
359358
360359 LogicalResult
361360 matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
362- ConversionPatternRewriter &rewriter) const override {
361+ ConversionPatternRewriter &rewriter) const final {
363362 Type resultType = typeConverter->convertType (reduceOp.getType ());
364363 if (!resultType)
365364 return failure ();
@@ -368,9 +367,22 @@ struct VectorReductionPattern final
368367 if (!srcVectorType || srcVectorType.getRank () != 1 )
369368 return rewriter.notifyMatchFailure (reduceOp, " not 1-D vector source" );
370369
371- // Extract all elements.
370+ SmallVector<Value> extractedElements =
371+ extractAllElements (reduceOp, adaptor, srcVectorType, rewriter);
372+
373+ const auto &self = static_cast <const Derived &>(*this );
374+
375+ return self.reduceExtracted (reduceOp, extractedElements, resultType,
376+ rewriter);
377+ }
378+
379+ private:
380+ SmallVector<Value>
381+ extractAllElements (vector::ReductionOp reduceOp, OpAdaptor adaptor,
382+ VectorType srcVectorType,
383+ ConversionPatternRewriter &rewriter) const {
372384 int numElements = srcVectorType.getDimSize (0 );
373- SmallVector<Value, 4 > values;
385+ SmallVector<Value> values;
374386 values.reserve (numElements + (adaptor.getAcc () != nullptr ));
375387 Location loc = reduceOp.getLoc ();
376388 for (int i = 0 ; i < numElements; ++i) {
@@ -381,9 +393,26 @@ struct VectorReductionPattern final
381393 if (Value acc = adaptor.getAcc ())
382394 values.push_back (acc);
383395
384- // Reduce them.
385- Value result = values.front ();
386- for (Value next : llvm::ArrayRef (values).drop_front ()) {
396+ return values;
397+ }
398+ };
399+
400+ #define VECTOR_REDUCTION_BASE \
401+ VectorReductionPatternBase<VectorReductionPattern<SPIRVUMaxOp, SPIRVUMinOp, \
402+ SPIRVSMaxOp, SPIRVSMinOp>>
403+ template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
404+ typename SPIRVSMinOp>
405+ struct VectorReductionPattern final : VECTOR_REDUCTION_BASE {
406+ using Base = VECTOR_REDUCTION_BASE;
407+ using Base::Base;
408+
409+ LogicalResult reduceExtracted (vector::ReductionOp reduceOp,
410+ ArrayRef<Value> extractedElements,
411+ Type resultType,
412+ ConversionPatternRewriter &rewriter) const {
413+ mlir::Location loc = reduceOp->getLoc ();
414+ Value result = extractedElements.front ();
415+ for (Value next : llvm::ArrayRef (extractedElements).drop_front ()) {
387416 switch (reduceOp.getKind ()) {
388417
389418#define INT_AND_FLOAT_CASE (kind, iop, fop ) \
@@ -403,10 +432,6 @@ struct VectorReductionPattern final
403432
404433 INT_AND_FLOAT_CASE (ADD, IAddOp, FAddOp);
405434 INT_AND_FLOAT_CASE (MUL, IMulOp, FMulOp);
406- INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
407- INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
408- INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
409- INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
410435 INT_OR_FLOAT_CASE (MINUI, SPIRVUMinOp);
411436 INT_OR_FLOAT_CASE (MINSI, SPIRVSMinOp);
412437 INT_OR_FLOAT_CASE (MAXUI, SPIRVUMaxOp);
@@ -416,13 +441,57 @@ struct VectorReductionPattern final
416441 case vector::CombiningKind::OR:
417442 case vector::CombiningKind::XOR:
418443 return rewriter.notifyMatchFailure (reduceOp, " unimplemented" );
444+ default :
445+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
419446 }
420447 }
421448
422449 rewriter.replaceOp (reduceOp, result);
423450 return success ();
424451 }
425452};
453+ #undef VECTOR_REDUCTION_BASE
454+ #undef INT_AND_FLOAT_CASE
455+ #undef INT_OR_FLOAT_CASE
456+
457+ #define MIN_MAX_PATTERN_BASE \
458+ VectorReductionPatternBase< \
459+ VectorReductionFloatMinMax<SPIRVFMaxOp, SPIRVFMinOp>>
460+ template <class SPIRVFMaxOp , class SPIRVFMinOp >
461+ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
462+ using Base = MIN_MAX_PATTERN_BASE;
463+ using Base::Base;
464+
465+ LogicalResult reduceExtracted (vector::ReductionOp reduceOp,
466+ ArrayRef<Value> extractedElements,
467+ Type resultType,
468+ ConversionPatternRewriter &rewriter) const {
469+ mlir::Location loc = reduceOp->getLoc ();
470+ Value result = extractedElements.front ();
471+ for (Value next : llvm::ArrayRef (extractedElements).drop_front ()) {
472+ switch (reduceOp.getKind ()) {
473+
474+ #define INT_OR_FLOAT_CASE (kind, fop ) \
475+ case vector::CombiningKind::kind: \
476+ result = rewriter.create <fop>(loc, resultType, result, next); \
477+ break
478+
479+ INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
480+ INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
481+ INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
482+ INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
483+
484+ default :
485+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
486+ }
487+ }
488+
489+ rewriter.replaceOp (reduceOp, result);
490+ return success ();
491+ }
492+ };
493+ #undef MIN_MAX_PATTERN_BASE
494+ #undef INT_OR_FLOAT_CASE
426495
427496class VectorSplatPattern final : public OpConversionPattern<vector::SplatOp> {
428497public:
@@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
604673};
605674
606675} // namespace
607- #define CL_MAX_MIN_OPS \
608- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
609- spirv::CLSMaxOp, spirv::CLSMinOp
676+ #define CL_INT_MAX_MIN_OPS \
677+ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
678+
679+ #define GL_INT_MAX_MIN_OPS \
680+ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
610681
611- #define GL_MAX_MIN_OPS \
612- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
613- spirv::GLSMaxOp, spirv::GLSMinOp
682+ #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
683+ #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
614684
615685void mlir::populateVectorToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
616686 RewritePatternSet &patterns) {
617- patterns.add <VectorBitcastConvert, VectorBroadcastConvert,
618- VectorExtractElementOpConvert, VectorExtractOpConvert,
619- VectorExtractStridedSliceOpConvert,
620- VectorFmaOpConvert<spirv::GLFmaOp>,
621- VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
622- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
623- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
624- VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
625- VectorSplatPattern>(typeConverter, patterns.getContext ());
687+ patterns.add <
688+ VectorBitcastConvert, VectorBroadcastConvert,
689+ VectorExtractElementOpConvert, VectorExtractOpConvert,
690+ VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691+ VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
693+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
694+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
695+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
696+ VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
697+ VectorSplatPattern>(typeConverter, patterns.getContext ());
626698}
627699
628700void mlir::populateVectorReductionToSPIRVDotProductPatterns (
0 commit comments