diff --git a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp index 9b29179f36871..1d46d9503e976 100644 --- a/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp +++ b/mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp @@ -397,9 +397,12 @@ struct VectorReductionPattern final break #define INT_OR_FLOAT_CASE(kind, fop) \ - case vector::CombiningKind::kind: \ - result = rewriter.create(loc, resultType, result, next); \ - break + case vector::CombiningKind::kind: { \ + fop op = rewriter.create(loc, resultType, result, next); \ + result = this->generateActionForOp(rewriter, loc, resultType, op, \ + vector::CombiningKind::kind); \ + break; \ + } INT_AND_FLOAT_CASE(ADD, IAddOp, FAddOp); INT_AND_FLOAT_CASE(MUL, IMulOp, FMulOp); @@ -422,6 +425,51 @@ struct VectorReductionPattern final rewriter.replaceOp(reduceOp, result); return success(); } + +private: + enum class Action { Nothing, PropagateNaN, PropagateNonNaN }; + + template + Action getActionForOp(vector::CombiningKind kind) const { + constexpr bool isCLOp = std::is_same_v || + std::is_same_v; + switch (kind) { + case vector::CombiningKind::MINIMUMF: + case vector::CombiningKind::MAXIMUMF: + return Action::PropagateNaN; + case vector::CombiningKind::MINF: + case vector::CombiningKind::MAXF: + // CL ops already have the same semantic for NaNs as MINF/MAXF + // GL ops have undefined semantics for NaNs, so we need to explicitly + // propagate the non-NaN values + return isCLOp ? Action::Nothing : Action::PropagateNonNaN; + default: + return Action::Nothing; + } + } + + template + Value generateActionForOp(ConversionPatternRewriter &rewriter, + mlir::Location loc, Type resultType, Op op, + vector::CombiningKind kind) const { + Action action = getActionForOp(kind); + + if (action == Action::Nothing) { + return op; + } + + Value lhsIsNan = rewriter.create(loc, op.getLhs()); + Value rhsIsNan = rewriter.create(loc, op.getRhs()); + + Value select1 = rewriter.create( + loc, resultType, lhsIsNan, + action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op); + Value select2 = rewriter.create( + loc, resultType, rhsIsNan, + action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1); + + return select2; + } }; class VectorSplatPattern final : public OpConversionPattern { diff --git a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir index eba763eab9c29..91836e556147b 100644 --- a/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir +++ b/mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir @@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector< // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]] -// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]] -// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]] -// CHECK: return %[[MAX2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]] +// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]] +// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]] +// CHECK: return %[[MAX2]] +func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// CHECK-LABEL: func @cl_reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]] // CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]] // CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]] // CHECK: return %[[MIN2]] -func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { - %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 +func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 } @@ -522,9 +574,21 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 { // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]] -// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]] -// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]] -// CHECK: return %[[MAX2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -532,15 +596,55 @@ func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 { // ----- +// CHECK-LABEL: func @reduction_maxf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + // CHECK-LABEL: func @reduction_minimumf // CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> // CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> // CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> // CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]] -// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]] -// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]] -// CHECK: return %[[MIN2]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 return %reduce : f32 @@ -548,6 +652,34 @@ func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 { // ----- +// CHECK-LABEL: func @reduction_minf +// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32) +// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32> +// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32> +// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32> +// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]] +// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32 +// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32 +// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32 +// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32 +// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]] +// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32 +// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32 +// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32 +// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32 +// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]] +// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32 +// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32 +// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32 +// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32 +// CHECK: return %[[SELECT5]] +func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 { + %reduce = vector.reduction , %v, %s : vector<3xf32> into f32 + return %reduce : f32 +} + +// ----- + // CHECK-LABEL: func @reduction_maxsi // CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32) // CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>