Skip to content

Commit 08284c1

Browse files
committed
[mlir][spirv] Fix vector reduction lowerings for FP min/max
This patch is part of a larger initiative aimed at fixing floating-point `max` and `min` operations in MLIR: https://discourse.llvm.org/t/rfc-fix-floating-point-max-and-min-operations-in-mlir/72671. This commit fixes the vector reduction lowerings for the floating-point min/max kinds by implementing additional generation of operations that propagate semantics. This patch addresses tasks 2.4 and 2.5 of the RFC.
1 parent a736e6b commit 08284c1

File tree

2 files changed

+195
-14
lines changed

2 files changed

+195
-14
lines changed

mlir/lib/Conversion/VectorToSPIRV/VectorToSPIRV.cpp

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
#include "llvm/ADT/ArrayRef.h"
3030
#include "llvm/ADT/STLExtras.h"
3131
#include "llvm/ADT/SmallVectorExtras.h"
32+
#include "llvm/Support/ErrorHandling.h"
3233
#include "llvm/Support/FormatVariadic.h"
3334
#include <cassert>
3435
#include <cstdint>
@@ -472,9 +473,12 @@ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
472473
switch (reduceOp.getKind()) {
473474

474475
#define INT_OR_FLOAT_CASE(kind, fop) \
475-
case vector::CombiningKind::kind: \
476-
result = rewriter.create<fop>(loc, resultType, result, next); \
477-
break
476+
case vector::CombiningKind::kind: { \
477+
fop op = rewriter.create<fop>(loc, resultType, result, next); \
478+
result = this->generateActionForOp(rewriter, loc, resultType, op, \
479+
vector::CombiningKind::kind); \
480+
break; \
481+
}
478482

479483
INT_OR_FLOAT_CASE(MAXIMUMF, SPIRVFMaxOp);
480484
INT_OR_FLOAT_CASE(MINIMUMF, SPIRVFMinOp);
@@ -489,6 +493,51 @@ struct VectorReductionFloatMinMax final : MIN_MAX_PATTERN_BASE {
489493
rewriter.replaceOp(reduceOp, result);
490494
return success();
491495
}
496+
497+
private:
498+
enum class Action { Nothing, PropagateNaN, PropagateNonNaN };
499+
500+
template <typename Op>
501+
Action getActionForOp(vector::CombiningKind kind) const {
502+
constexpr bool isCLOp = std::is_same_v<Op, spirv::CLFMaxOp> ||
503+
std::is_same_v<Op, spirv::CLFMinOp>;
504+
switch (kind) {
505+
case vector::CombiningKind::MINIMUMF:
506+
case vector::CombiningKind::MAXIMUMF:
507+
return Action::PropagateNaN;
508+
case vector::CombiningKind::MINF:
509+
case vector::CombiningKind::MAXF:
510+
// CL ops already have the same semantic for NaNs as MINF/MAXF
511+
// GL ops have undefined semantics for NaNs, so we need to explicitly
512+
// propagate the non-NaN values
513+
return isCLOp ? Action::Nothing : Action::PropagateNonNaN;
514+
default:
515+
llvm_unreachable("Unexpected case for the switch");
516+
}
517+
}
518+
519+
template <typename Op>
520+
Value generateActionForOp(ConversionPatternRewriter &rewriter,
521+
mlir::Location loc, Type resultType, Op op,
522+
vector::CombiningKind kind) const {
523+
Action action = getActionForOp<Op>(kind);
524+
525+
if (action == Action::Nothing) {
526+
return op;
527+
}
528+
529+
Value lhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getLhs());
530+
Value rhsIsNan = rewriter.create<spirv::IsNanOp>(loc, op.getRhs());
531+
532+
Value select1 = rewriter.create<spirv::SelectOp>(
533+
loc, resultType, lhsIsNan,
534+
action == Action::PropagateNaN ? op.getLhs() : op.getRhs(), op);
535+
Value select2 = rewriter.create<spirv::SelectOp>(
536+
loc, resultType, rhsIsNan,
537+
action == Action::PropagateNaN ? op.getRhs() : op.getLhs(), select1);
538+
539+
return select2;
540+
}
492541
};
493542
#undef MIN_MAX_PATTERN_BASE
494543
#undef INT_OR_FLOAT_CASE

mlir/test/Conversion/VectorToSPIRV/vector-to-spirv.mlir

Lines changed: 143 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,21 @@ func.func @cl_fma_size1_vector(%a: vector<1xf32>, %b: vector<1xf32>, %c: vector<
5656
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
5757
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
5858
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
59-
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
60-
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
61-
// CHECK: return %[[MAX2]]
59+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
60+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
61+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
62+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
63+
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[SELECT1]], %[[S2]]
64+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
65+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
66+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
67+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
68+
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[SELECT3]], %[[S]]
69+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
70+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
71+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
72+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
73+
// CHECK: return %[[SELECT5]]
6274
func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
6375
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
6476
return %reduce : f32
@@ -70,11 +82,51 @@ func.func @cl_reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
7082
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
7183
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
7284
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
85+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
86+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
87+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
88+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
89+
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[SELECT1]], %[[S2]]
90+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
91+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
92+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
93+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
94+
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[SELECT3]], %[[S]]
95+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
96+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
97+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
98+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
99+
// CHECK: return %[[SELECT5]]
100+
func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
101+
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
102+
return %reduce : f32
103+
}
104+
105+
// CHECK-LABEL: func @cl_reduction_maxf
106+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
107+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
108+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
109+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
110+
// CHECK: %[[MAX0:.+]] = spirv.CL.fmax %[[S0]], %[[S1]]
111+
// CHECK: %[[MAX1:.+]] = spirv.CL.fmax %[[MAX0]], %[[S2]]
112+
// CHECK: %[[MAX2:.+]] = spirv.CL.fmax %[[MAX1]], %[[S]]
113+
// CHECK: return %[[MAX2]]
114+
func.func @cl_reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
115+
%reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
116+
return %reduce : f32
117+
}
118+
119+
// CHECK-LABEL: func @cl_reduction_minf
120+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
121+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
122+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
123+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
124+
// CHECK: %[[MIN0:.+]] = spirv.CL.fmin %[[S0]], %[[S1]]
73125
// CHECK: %[[MIN1:.+]] = spirv.CL.fmin %[[MIN0]], %[[S2]]
74126
// CHECK: %[[MIN2:.+]] = spirv.CL.fmin %[[MIN1]], %[[S]]
75127
// CHECK: return %[[MIN2]]
76-
func.func @cl_reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
77-
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
128+
func.func @cl_reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
129+
%reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
78130
return %reduce : f32
79131
}
80132

@@ -522,32 +574,112 @@ func.func @reduction_mul(%v : vector<3xf32>, %s: f32) -> f32 {
522574
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
523575
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
524576
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
525-
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[MAX0]], %[[S2]]
526-
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[MAX1]], %[[S]]
527-
// CHECK: return %[[MAX2]]
577+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
578+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
579+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MAX0]] : i1, f32
580+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
581+
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
582+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
583+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
584+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MAX1]] : i1, f32
585+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
586+
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
587+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
588+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
589+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MAX2]] : i1, f32
590+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
591+
// CHECK: return %[[SELECT5]]
528592
func.func @reduction_maximumf(%v : vector<3xf32>, %s: f32) -> f32 {
529593
%reduce = vector.reduction <maximumf>, %v, %s : vector<3xf32> into f32
530594
return %reduce : f32
531595
}
532596

533597
// -----
534598

599+
// CHECK-LABEL: func @reduction_maxf
600+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
601+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
602+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
603+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
604+
// CHECK: %[[MAX0:.+]] = spirv.GL.FMax %[[S0]], %[[S1]]
605+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
606+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
607+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MAX0]] : i1, f32
608+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
609+
// CHECK: %[[MAX1:.+]] = spirv.GL.FMax %[[SELECT1]], %[[S2]]
610+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
611+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
612+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MAX1]] : i1, f32
613+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
614+
// CHECK: %[[MAX2:.+]] = spirv.GL.FMax %[[SELECT3]], %[[S]]
615+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
616+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
617+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MAX2]] : i1, f32
618+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
619+
// CHECK: return %[[SELECT5]]
620+
func.func @reduction_maxf(%v : vector<3xf32>, %s: f32) -> f32 {
621+
%reduce = vector.reduction <maxf>, %v, %s : vector<3xf32> into f32
622+
return %reduce : f32
623+
}
624+
625+
// -----
626+
535627
// CHECK-LABEL: func @reduction_minimumf
536628
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
537629
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
538630
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
539631
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
540632
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
541-
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[MIN0]], %[[S2]]
542-
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[MIN1]], %[[S]]
543-
// CHECK: return %[[MIN2]]
633+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
634+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
635+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S0]], %[[MIN0]] : i1, f32
636+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S1]], %[[SELECT0]] : i1, f32
637+
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
638+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
639+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
640+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[SELECT1]], %[[MIN1]] : i1, f32
641+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[S2]], %[[SELECT2]] : i1, f32
642+
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
643+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
644+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
645+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[SELECT3]], %[[MIN2]] : i1, f32
646+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[S]], %[[SELECT4]] : i1, f32
647+
// CHECK: return %[[SELECT5]]
544648
func.func @reduction_minimumf(%v : vector<3xf32>, %s: f32) -> f32 {
545649
%reduce = vector.reduction <minimumf>, %v, %s : vector<3xf32> into f32
546650
return %reduce : f32
547651
}
548652

549653
// -----
550654

655+
// CHECK-LABEL: func @reduction_minf
656+
// CHECK-SAME: (%[[V:.+]]: vector<3xf32>, %[[S:.+]]: f32)
657+
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xf32>
658+
// CHECK: %[[S1:.+]] = spirv.CompositeExtract %[[V]][1 : i32] : vector<3xf32>
659+
// CHECK: %[[S2:.+]] = spirv.CompositeExtract %[[V]][2 : i32] : vector<3xf32>
660+
// CHECK: %[[MIN0:.+]] = spirv.GL.FMin %[[S0]], %[[S1]]
661+
// CHECK: %[[ISNAN0:.+]] = spirv.IsNan %[[S0]] : f32
662+
// CHECK: %[[ISNAN1:.+]] = spirv.IsNan %[[S1]] : f32
663+
// CHECK: %[[SELECT0:.+]] = spirv.Select %[[ISNAN0]], %[[S1]], %[[MIN0]] : i1, f32
664+
// CHECK: %[[SELECT1:.+]] = spirv.Select %[[ISNAN1]], %[[S0]], %[[SELECT0]] : i1, f32
665+
// CHECK: %[[MIN1:.+]] = spirv.GL.FMin %[[SELECT1]], %[[S2]]
666+
// CHECK: %[[ISNAN2:.+]] = spirv.IsNan %[[SELECT1]] : f32
667+
// CHECK: %[[ISNAN3:.+]] = spirv.IsNan %[[S2]] : f32
668+
// CHECK: %[[SELECT2:.+]] = spirv.Select %[[ISNAN2]], %[[S2]], %[[MIN1]] : i1, f32
669+
// CHECK: %[[SELECT3:.+]] = spirv.Select %[[ISNAN3]], %[[SELECT1]], %[[SELECT2]] : i1, f32
670+
// CHECK: %[[MIN2:.+]] = spirv.GL.FMin %[[SELECT3]], %[[S]]
671+
// CHECK: %[[ISNAN4:.+]] = spirv.IsNan %[[SELECT3]] : f32
672+
// CHECK: %[[ISNAN5:.+]] = spirv.IsNan %[[S]] : f32
673+
// CHECK: %[[SELECT4:.+]] = spirv.Select %[[ISNAN4]], %[[S]], %[[MIN2]] : i1, f32
674+
// CHECK: %[[SELECT5:.+]] = spirv.Select %[[ISNAN5]], %[[SELECT3]], %[[SELECT4]] : i1, f32
675+
// CHECK: return %[[SELECT5]]
676+
func.func @reduction_minf(%v : vector<3xf32>, %s: f32) -> f32 {
677+
%reduce = vector.reduction <minf>, %v, %s : vector<3xf32> into f32
678+
return %reduce : f32
679+
}
680+
681+
// -----
682+
551683
// CHECK-LABEL: func @reduction_maxsi
552684
// CHECK-SAME: (%[[V:.+]]: vector<3xi32>, %[[S:.+]]: i32)
553685
// CHECK: %[[S0:.+]] = spirv.CompositeExtract %[[V]][0 : i32] : vector<3xi32>

0 commit comments

Comments
 (0)