2020#include " mlir/Dialect/Vector/IR/VectorOps.h"
2121#include " mlir/IR/BuiltinAttributes.h"
2222#include " mlir/IR/BuiltinTypes.h"
23+ #include " mlir/IR/Location.h"
2324#include " mlir/IR/Matchers.h"
2425#include " mlir/IR/PatternMatch.h"
2526#include " mlir/IR/TypeUtilities.h"
2627#include " mlir/Support/LogicalResult.h"
2728#include " mlir/Transforms/DialectConversion.h"
2829#include " llvm/ADT/ArrayRef.h"
2930#include " llvm/ADT/STLExtras.h"
31+ #include " llvm/ADT/SmallVector.h"
3032#include " llvm/ADT/SmallVectorExtras.h"
3133#include " llvm/Support/FormatVariadic.h"
3234#include < cassert>
@@ -351,39 +353,64 @@ struct VectorInsertStridedSliceOpConvert final
351353 }
352354};
353355
354- template <class SPIRVFMaxOp , class SPIRVFMinOp , class SPIRVUMaxOp ,
355- class SPIRVUMinOp , class SPIRVSMaxOp , class SPIRVSMinOp >
356- struct VectorReductionPattern final
357- : public OpConversionPattern<vector::ReductionOp> {
356+ static SmallVector<Value> extractAllElements (
357+ vector::ReductionOp reduceOp, vector::ReductionOp::Adaptor adaptor,
358+ VectorType srcVectorType, ConversionPatternRewriter &rewriter) {
359+ int numElements = static_cast <int >(srcVectorType.getDimSize (0 ));
360+ SmallVector<Value> values;
361+ values.reserve (numElements + (adaptor.getAcc () ? 1 : 0 ));
362+ Location loc = reduceOp.getLoc ();
363+
364+ for (int i = 0 ; i < numElements; ++i) {
365+ values.push_back (rewriter.create <spirv::CompositeExtractOp>(
366+ loc, srcVectorType.getElementType (), adaptor.getVector (),
367+ rewriter.getI32ArrayAttr ({i})));
368+ }
369+ if (Value acc = adaptor.getAcc ())
370+ values.push_back (acc);
371+
372+ return values;
373+ }
374+
375+ struct ReductionRewriteInfo {
376+ Type resultType;
377+ SmallVector<Value> extractedElements;
378+ };
379+
380+ FailureOr<ReductionRewriteInfo> static getReductionInfo (
381+ vector::ReductionOp op, vector::ReductionOp::Adaptor adaptor,
382+ ConversionPatternRewriter &rewriter, const TypeConverter &typeConverter) {
383+ Type resultType = typeConverter.convertType (op.getType ());
384+ if (!resultType)
385+ return failure ();
386+
387+ auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector ().getType ());
388+ if (!srcVectorType || srcVectorType.getRank () != 1 )
389+ return rewriter.notifyMatchFailure (op, " not a 1-D vector source" );
390+
391+ SmallVector<Value> extractedElements =
392+ extractAllElements (op, adaptor, srcVectorType, rewriter);
393+
394+ return ReductionRewriteInfo{resultType, std::move (extractedElements)};
395+ }
396+
397+ template <typename SPIRVUMaxOp, typename SPIRVUMinOp, typename SPIRVSMaxOp,
398+ typename SPIRVSMinOp>
399+ struct VectorReductionPattern final : OpConversionPattern<vector::ReductionOp> {
358400 using OpConversionPattern::OpConversionPattern;
359401
360402 LogicalResult
361403 matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
362404 ConversionPatternRewriter &rewriter) const override {
363- Type resultType = typeConverter->convertType (reduceOp.getType ());
364- if (!resultType)
405+ auto reductionInfo =
406+ getReductionInfo (reduceOp, adaptor, rewriter, *getTypeConverter ());
407+ if (failed (reductionInfo))
365408 return failure ();
366409
367- auto srcVectorType = dyn_cast<VectorType>(adaptor.getVector ().getType ());
368- if (!srcVectorType || srcVectorType.getRank () != 1 )
369- return rewriter.notifyMatchFailure (reduceOp, " not 1-D vector source" );
370-
371- // Extract all elements.
372- int numElements = srcVectorType.getDimSize (0 );
373- SmallVector<Value, 4 > values;
374- values.reserve (numElements + (adaptor.getAcc () != nullptr ));
375- Location loc = reduceOp.getLoc ();
376- for (int i = 0 ; i < numElements; ++i) {
377- values.push_back (rewriter.create <spirv::CompositeExtractOp>(
378- loc, srcVectorType.getElementType (), adaptor.getVector (),
379- rewriter.getI32ArrayAttr ({i})));
380- }
381- if (Value acc = adaptor.getAcc ())
382- values.push_back (acc);
383-
384- // Reduce them.
385- Value result = values.front ();
386- for (Value next : llvm::ArrayRef (values).drop_front ()) {
410+ auto [resultType, extractedElements] = *reductionInfo;
411+ Location loc = reduceOp->getLoc ();
412+ Value result = extractedElements.front ();
413+ for (Value next : llvm::drop_begin (extractedElements)) {
387414 switch (reduceOp.getKind ()) {
388415
389416#define INT_AND_FLOAT_CASE (kind, iop, fop ) \
@@ -403,10 +430,6 @@ struct VectorReductionPattern final
403430
404431 INT_AND_FLOAT_CASE (ADD, IAddOp, FAddOp);
405432 INT_AND_FLOAT_CASE (MUL, IMulOp, FMulOp);
406- INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
407- INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
408- INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
409- INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
410433 INT_OR_FLOAT_CASE (MINUI, SPIRVUMinOp);
411434 INT_OR_FLOAT_CASE (MINSI, SPIRVSMinOp);
412435 INT_OR_FLOAT_CASE (MAXUI, SPIRVUMaxOp);
@@ -416,7 +439,51 @@ struct VectorReductionPattern final
416439 case vector::CombiningKind::OR:
417440 case vector::CombiningKind::XOR:
418441 return rewriter.notifyMatchFailure (reduceOp, " unimplemented" );
442+ default :
443+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
419444 }
445+ #undef INT_AND_FLOAT_CASE
446+ #undef INT_OR_FLOAT_CASE
447+ }
448+
449+ rewriter.replaceOp (reduceOp, result);
450+ return success ();
451+ }
452+ };
453+
454+ template <typename SPIRVFMaxOp, typename SPIRVFMinOp>
455+ struct VectorReductionFloatMinMax final
456+ : OpConversionPattern<vector::ReductionOp> {
457+ using OpConversionPattern::OpConversionPattern;
458+
459+ LogicalResult
460+ matchAndRewrite (vector::ReductionOp reduceOp, OpAdaptor adaptor,
461+ ConversionPatternRewriter &rewriter) const override {
462+ auto reductionInfo =
463+ getReductionInfo (reduceOp, adaptor, rewriter, *getTypeConverter ());
464+ if (failed (reductionInfo))
465+ return failure ();
466+
467+ auto [resultType, extractedElements] = *reductionInfo;
468+ Location loc = reduceOp->getLoc ();
469+ Value result = extractedElements.front ();
470+ for (Value next : llvm::drop_begin (extractedElements)) {
471+ switch (reduceOp.getKind ()) {
472+
473+ #define INT_OR_FLOAT_CASE (kind, fop ) \
474+ case vector::CombiningKind::kind: \
475+ result = rewriter.create <fop>(loc, resultType, result, next); \
476+ break
477+
478+ INT_OR_FLOAT_CASE (MAXIMUMF, SPIRVFMaxOp);
479+ INT_OR_FLOAT_CASE (MINIMUMF, SPIRVFMinOp);
480+ INT_OR_FLOAT_CASE (MAXF, SPIRVFMaxOp);
481+ INT_OR_FLOAT_CASE (MINF, SPIRVFMinOp);
482+
483+ default :
484+ return rewriter.notifyMatchFailure (reduceOp, " not handled here" );
485+ }
486+ #undef INT_OR_FLOAT_CASE
420487 }
421488
422489 rewriter.replaceOp (reduceOp, result);
@@ -674,13 +741,14 @@ struct VectorReductionToDotProd final : OpRewritePattern<vector::ReductionOp> {
674741};
675742
676743} // namespace
677- #define CL_MAX_MIN_OPS \
678- spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \
679- spirv::CLSMaxOp, spirv::CLSMinOp
744+ #define CL_INT_MAX_MIN_OPS \
745+ spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp
746+
747+ #define GL_INT_MAX_MIN_OPS \
748+ spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp
680749
681- #define GL_MAX_MIN_OPS \
682- spirv::GLFMaxOp, spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \
683- spirv::GLSMaxOp, spirv::GLSMinOp
750+ #define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp
751+ #define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp
684752
685753void mlir::populateVectorToSPIRVPatterns (SPIRVTypeConverter &typeConverter,
686754 RewritePatternSet &patterns) {
@@ -689,8 +757,10 @@ void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
689757 VectorExtractElementOpConvert, VectorExtractOpConvert,
690758 VectorExtractStridedSliceOpConvert, VectorFmaOpConvert<spirv::GLFmaOp>,
691759 VectorFmaOpConvert<spirv::CLFmaOp>, VectorInsertElementOpConvert,
692- VectorInsertOpConvert, VectorReductionPattern<GL_MAX_MIN_OPS>,
693- VectorReductionPattern<CL_MAX_MIN_OPS>, VectorShapeCast,
760+ VectorInsertOpConvert, VectorReductionPattern<GL_INT_MAX_MIN_OPS>,
761+ VectorReductionPattern<CL_INT_MAX_MIN_OPS>,
762+ VectorReductionFloatMinMax<CL_FLOAT_MAX_MIN_OPS>,
763+ VectorReductionFloatMinMax<GL_FLOAT_MAX_MIN_OPS>, VectorShapeCast,
694764 VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert,
695765 VectorSplatPattern, VectorLoadOpConverter, VectorStoreOpConverter>(
696766 typeConverter, patterns.getContext ());
0 commit comments