From a736e6b5dfff09a8a24e17e130b3b8314ad8c825 Mon Sep 17 00:00:00 2001 From: Daniil Dudkin Date: Thu, 21 Sep 2023 22:47:21 +0300 Subject: [PATCH 1/2] [mlir][spirv] Split codegen for float min/max reductions and others (NFC) This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. There are two types of min/max operations for floating-point numbers: `minf`/`maxf` and `minimumf`/`maximumf`. The code generation for these operations should differ from that of other vector reduction kinds. This difference arises because CL and GL operations for floating-point min and max do not have the same semantics when handling NaNs. Therefore, we must enforce the desired semantics with additional ops. However, since the code generation for floating-point min/max operations shares the step of extracting the values from the vector with the other reduction kinds, we have decided to refactor the existing code using the CRTP pattern. This change does not alter the actual behavior of the code and is necessary for future fixes to the codegen for floating-point min/max operations. 
--- .../VectorToSPIRV/VectorToSPIRV.cpp | 130 ++++++++++++++---- 1 file changed, 101 insertions(+), 29 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 9b29179f36871..c4c0497c2d1f0 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -20,6 +20,7 @@ #include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Location.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/PatternMatch.h" #include "mlir/IR/TypeUtilities.h" @@ -351,15 +352,13 @@ struct VectorInsertStridedSliceOpConvert final } }; -template -struct VectorReductionPattern final - : public OpConversionPattern { +template +struct VectorReductionPatternBase : OpConversionPattern { using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(vector::ReductionOp reduceOp, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const final { Type resultType = typeConverter->convertType(reduceOp.getType()); if (!resultType) return failure(); @@ -368,9 +367,22 @@ struct VectorReductionPattern final if (!srcVectorType || srcVectorType.getRank() != 1) return rewriter.notifyMatchFailure(reduceOp, "not 1-D vector source"); - // Extract all elements. 
+ SmallVector extractedElements = + extractAllElements(reduceOp, adaptor, srcVectorType, rewriter); + + const auto &self = static_cast(*this); + + return self.reduceExtracted(reduceOp, extractedElements, resultType, + rewriter); + } + +private: + SmallVector + extractAllElements(vector::ReductionOp reduceOp, OpAdaptor adaptor, + VectorType srcVectorType, + ConversionPatternRewriter &rewriter) const { int numElements = srcVectorType.getDimSize(0); - SmallVector values; + SmallVector values; values.reserve(numElements + (adaptor.getAcc() != nullptr)); Location loc = reduceOp.getLoc(); for (int i = 0; i < numElements; ++i) { @@ -381,9 +393,26 @@ struct VectorReductionPattern final if (Value acc = adaptor.getAcc()) values.push_back(acc); - // Reduce them. - Value result = values.front(); - for (Value next : llvm::ArrayRef(values).drop_front()) { + return values; + } +}; + +#define VECTOR_REDUCTION_BASE \ + VectorReductionPatternBase> +template +struct VectorReductionPattern final : VECTOR_REDUCTION_BASE { + using Base = VECTOR_REDUCTION_BASE; + using Base::Base; + + LogicalResult reduceExtracted(vector::ReductionOp reduceOp, + ArrayRef extractedElements, + Type resultType, + ConversionPatternRewriter &rewriter) const { + mlir::Location loc = reduceOp->getLoc(); + Value result = extractedElements.front(); + for (Value next : llvm::ArrayRef(extractedElements).drop_front()) { switch (reduceOp.getKind()) { #define INT_AND_FLOAT_CASE(kind, iop, fop) \ @@ -403,10 +432,6 @@ struct VectorReductionPattern final INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); - INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); - INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); - INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp); - INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp); INT_OR_FLOAT_CASE(MINUI, SPIRVUMinOp); INT_OR_FLOAT_CASE(MINSI, SPIRVSMinOp); INT_OR_FLOAT_CASE(MAXUI, SPIRVUMaxOp); @@ -416,6 +441,8 @@ struct VectorReductionPattern final case vector::CombiningKind::OR: case 
vector::CombiningKind::XOR: return rewriter.notifyMatchFailure(reduceOp, "unimplemented"); + default: + return rewriter.notifyMatchFailure(reduceOp, "not handled here"); } } @@ -423,6 +450,48 @@ struct VectorReductionPattern final return success(); } }; +#undef VECTOR_REDUCTION_BASE +#undef INT_AND_FLOAT_CASE +#undef INT_OR_FLOAT_CASE + +#define MIN_MAX_PATTERN_BASE \ + VectorReductionPatternBase< \ + VectorReductionFloatMinMax> +template +struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE { + using Base = MIN_MAX_PATTERN_BASE; + using Base::Base; + + LogicalResult reduceExtracted(vector::ReductionOp reduceOp, + ArrayRef extractedElements, + Type resultType, + ConversionPatternRewriter &rewriter) const { + mlir::Location loc = reduceOp->getLoc(); + Value result = extractedElements.front(); + for (Value next : llvm::ArrayRef(extractedElements).drop_front()) { + switch (reduceOp.getKind()) { + +#define INT_OR_FLOAT_CASE(kind, fop) \ + case vector::CombiningKind::kind: \ + result = rewriter.create(loc, resultType, result, next); \ + break + + INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); + INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); + INT_OR_FLOAT_CASE(MAXF, SPIRVFMaxOp); + INT_OR_FLOAT_CASE(MINF, SPIRVFMinOp); + + default: + return rewriter.notifyMatchFailure(reduceOp, "not handled here"); + } + } + + rewriter.replaceOp(reduceOp, result); + return success(); + } +}; +#undef MIN_MAX_PATTERN_BASE +#undef INT_OR_FLOAT_CASE class VectorSplatPattern final : public OpConversionPattern { public: @@ -604,25 +673,28 @@ struct VectorReductionToDotProd final : OpRewritePattern { }; } // namespace -#define CL_MAX_MIN_OPS \ - spirv::CLFMaxOp, spirv::CLFMinOp, spirv::CLUMaxOp, spirv::CLUMinOp, \ - spirv::CLSMaxOp, spirv::CLSMinOp +#define CL_INT_MAX_MIN_OPS \ + spirv::CLUMaxOp, spirv::CLUMinOp, spirv::CLSMaxOp, spirv::CLSMinOp + +#define GL_INT_MAX_MIN_OPS \ + spirv::GLUMaxOp, spirv::GLUMinOp, spirv::GLSMaxOp, spirv::GLSMinOp -#define GL_MAX_MIN_OPS \ - spirv::GLFMaxOp, 
spirv::GLFMinOp, spirv::GLUMaxOp, spirv::GLUMinOp, \ - spirv::GLSMaxOp, spirv::GLSMinOp +#define CL_FLOAT_MAX_MIN_OPS spirv::CLFMaxOp, spirv::CLFMinOp +#define GL_FLOAT_MAX_MIN_OPS spirv::GLFMaxOp, spirv::GLFMinOp void mlir::populateVectorToSPIRVPatterns(SPIRVTypeConverter &typeConverter, RewritePatternSet &patterns) { - patterns.add, - VectorFmaOpConvert, VectorInsertElementOpConvert, - VectorInsertOpConvert, VectorReductionPattern, - VectorReductionPattern, VectorShapeCast, - VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, - VectorSplatPattern>(typeConverter, patterns.getContext()); + patterns.add< + VectorBitcastConvert, VectorBroadcastConvert, + VectorExtractElementOpConvert, VectorExtractOpConvert, + VectorExtractStridedSliceOpConvert, VectorFmaOpConvert, + VectorFmaOpConvert, VectorInsertElementOpConvert, + VectorInsertOpConvert, VectorReductionPattern, + VectorReductionPattern, + VectorReductionFloatMinMax, + VectorReductionFloatMinMax, VectorShapeCast, + VectorInsertStridedSliceOpConvert, VectorShuffleOpConvert, + VectorSplatPattern>(typeConverter, patterns.getContext()); } void mlir::populateVectorReductionToSPIRVDotProductPatterns( From 08284c15ac0379706e390718aa78657bbc986224 Mon Sep 17 00:00:00 2001 From: Daniil Dudkin Date: Fri, 13 Oct 2023 22:29:58 +0300 Subject: [PATCH 2/2] [mlir][spirv] Fix vector reduction lowerings for FP min/max This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit fixes the vector reduction lowerings for the floating-point min/max kinds by emitting additional `spirv.IsNan` and `spirv.Select` operations after each min/max op to enforce the required NaN semantics: NaN is propagated for `minimumf`/`maximumf`, and the non-NaN operand is propagated for `minf`/`maxf` when lowering to GL ops (CL ops already have the `minf`/`maxf` semantics, so nothing extra is emitted for them). This patch addresses tasks 2.4 and 2.5 of the RFC. 
--- .../VectorToSPIRV/VectorToSPIRV.cpp | 55 ++++++- .../VectorToSPIRV/vector-to-spirv.mlir | 154 ++++++++++++++++-- 2 files changed, 195 insertions(+), 14 deletions(-) diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index c4c0497c2d1f0..040fa69e2e9f2 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -29,6 +29,7 @@ #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include #include @@ -472,9 +473,12 @@ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE { switch (reduceOp.getKind()) { #define INT_OR_FLOAT_CASE(kind, fop) \ - case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ - break + case vector::CombiningKind::kind: { \ + fop op = rewriter.create(loc, resultType, result, next); \ + result = this->generateActionForOp(rewriter, loc, resultType, op, \ + vector::CombiningKind::kind); \ + break; \ + } INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp); INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp); @@ -489,6 +493,51 @@ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE { rewriter.replaceOp(reduceOp, result); return success(); } + +private: + enum class Action { Nothing, PropagateNaN, PropagateNonNaN }; + + template + Action getActionForOp(vector::CombiningKind kind) const { + constexpr bool isCLOp = std::is_same_v || + std::is_same_v; + switch (kind) { + case vector::CombiningKind::MINIMUMF: + case vector::CombiningKind::MAXIMUMF: + return Action::PropagateNaN; + case vector::CombiningKind::MINF: + case vector::CombiningKind::MAXF: + // CL ops already have the same semantic for NaNs as MINF/MAXF + // GL ops have undefined semantics for NaNs, so we need to explicitly + // propagate the non-NaN values + return isCLOp ? 
Action::Nothing : Action::PropagateNonNaN; + default: + llvm_unreachable("Unexpected case for the switch"); + } + } + + template + Value generateActionForOp(ConversionPatternRewriter &rewriter, + mlir::Location loc, Type resultType, Op op, + vector::CombiningKind kind) const { + Action action = getActionForOp(kind); + + if (action == Action::Nothing) { + return op; + } + + Value lhsIsNan = rewriter.create(loc, op.getLhs()); + Value rhsIsNan = rewriter.create(loc, op.getRhs()); + + Value select1 = rewriter.create( + loc, resultType, lhsIsNan, + action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op); + Value select2 = rewriter.create( + loc, resultType, rhsIsNan, + action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1); + + return select2; + } }; #undef MIN_MAX_PATTERN_BASE #undef INT_OR_FLOAT_CASE diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index eba763eab9c29..91836e556147b 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector< // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]] -// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]] -// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]] -// CHECK: return %[[MAX2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: 
%[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select 
%[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]] // CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]] // CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]] // CHECK: return %[[MIN2]] -func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { - %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 +func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 } @@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 { // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : 
vector<3xf32> // CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]] -// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]] -// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]] -// CHECK: return %[[MAX2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { // ----- +// CHECK-LABEL: func @reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] 
= spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + // CHECK-LABEL: func @reduction_minimumf // CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]] -// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]] -// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]] -// CHECK: return %[[MIN2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: 
%[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -548,6 +652,34 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { // ----- +// CHECK-LABEL: func @reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: 
%[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + // CHECK-LABEL: func @reduction_maxsi // CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>