Skip to content

Commit a2f5832

Browse files
rotateright authored and tstellar committed
[x86] invert a vector select IR canonicalization with a binop identity constant
This is an intentionally limited/different form of D90113. That patch bravely tries to generalize folds where we pull a binop into the arms of a select:
N0 + (Cond ? 0 : FVal) --> Cond ? N0 : (N0 + FVal)
...but it is not universally profitable.

This is the inverse of IR canonicalization as discussed in D113442. We know that this transform is not entirely profitable even within x86, so we only handle x86 vector fadd/fsub as a 1st step. The intent is to prevent AVX512 regressions as mentioned in D113442.

The plan is to port this to DAGCombiner (so it will eventually look more like D90113) and add more types/cases in pieces with many more tests to verify that we are seeing improvements.

Differential Revision: https://reviews.llvm.org/D118644

(cherry picked from commit 6592bce)
1 parent 4f624dd commit a2f5832

File tree

3 files changed

+102
-26
lines changed

3 files changed

+102
-26
lines changed

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48920,6 +48920,83 @@ static SDValue combineFaddCFmul(SDNode *N, SelectionDAG &DAG,
4892048920
return DAG.getBitcast(VT, CFmul);
4892148921
}
4892248922

48923+
/// This inverts a canonicalization in IR that replaces a variable select arm
48924+
/// with an identity constant. Codegen improves if we re-use the variable
48925+
/// operand rather than load a constant. This can also be converted into a
48926+
/// masked vector operation if the target supports it.
48927+
static SDValue foldSelectWithIdentityConstant(SDNode *N, SelectionDAG &DAG,
48928+
bool ShouldCommuteOperands) {
48929+
// Match a select as operand 1. The identity constant that we are looking for
48930+
// is only valid as operand 1 of a non-commutative binop.
48931+
SDValue N0 = N->getOperand(0);
48932+
SDValue N1 = N->getOperand(1);
48933+
if (ShouldCommuteOperands)
48934+
std::swap(N0, N1);
48935+
48936+
// TODO: Should this apply to scalar select too?
48937+
if (!N1.hasOneUse() || N1.getOpcode() != ISD::VSELECT)
48938+
return SDValue();
48939+
48940+
unsigned Opcode = N->getOpcode();
48941+
EVT VT = N->getValueType(0);
48942+
SDValue Cond = N1.getOperand(0);
48943+
SDValue TVal = N1.getOperand(1);
48944+
SDValue FVal = N1.getOperand(2);
48945+
48946+
// TODO: This (and possibly the entire function) belongs in a
48947+
// target-independent location with target hooks.
48948+
// TODO: The cases should match with IR's ConstantExpr::getBinOpIdentity().
48949+
// TODO: With fast-math (NSZ), allow the opposite-sign form of zero?
48950+
auto isIdentityConstantForOpcode = [](unsigned Opcode, SDValue V) {
48951+
if (ConstantFPSDNode *C = isConstOrConstSplatFP(V)) {
48952+
switch (Opcode) {
48953+
case ISD::FADD: // X + -0.0 --> X
48954+
return C->isZero() && C->isNegative();
48955+
case ISD::FSUB: // X - 0.0 --> X
48956+
return C->isZero() && !C->isNegative();
48957+
}
48958+
}
48959+
return false;
48960+
};
48961+
48962+
// This transform increases uses of N0, so freeze it to be safe.
48963+
// binop N0, (vselect Cond, IDC, FVal) --> vselect Cond, N0, (binop N0, FVal)
48964+
if (isIdentityConstantForOpcode(Opcode, TVal)) {
48965+
SDValue F0 = DAG.getFreeze(N0);
48966+
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, FVal, N->getFlags());
48967+
return DAG.getSelect(SDLoc(N), VT, Cond, F0, NewBO);
48968+
}
48969+
// binop N0, (vselect Cond, TVal, IDC) --> vselect Cond, (binop N0, TVal), N0
48970+
if (isIdentityConstantForOpcode(Opcode, FVal)) {
48971+
SDValue F0 = DAG.getFreeze(N0);
48972+
SDValue NewBO = DAG.getNode(Opcode, SDLoc(N), VT, F0, TVal, N->getFlags());
48973+
return DAG.getSelect(SDLoc(N), VT, Cond, NewBO, F0);
48974+
}
48975+
48976+
return SDValue();
48977+
}
48978+
48979+
/// Try to fold a binop whose operand is a vector select with an identity
/// constant arm (see foldSelectWithIdentityConstant). The select is looked for
/// as operand 1 first, and as operand 0 only when the opcode is commutative.
///
/// \param N the binop node being combined.
/// \param DAG the DAG in which replacement nodes are created.
/// \param Subtarget used to restrict the fold to supported AVX512 configs.
/// \returns the replacement value, or an empty SDValue if no fold applies.
static SDValue combineBinopWithSelect(SDNode *N, SelectionDAG &DAG,
                                      const X86Subtarget &Subtarget) {
  // TODO: This is too general. There are cases where pre-AVX512 codegen would
  //       benefit. The transform may also be profitable for scalar code.
  // Require AVX512, and without VLX only accept 512-bit vector results.
  const bool IsSupported =
      Subtarget.hasAVX512() &&
      (Subtarget.hasVLX() || N->getValueType(0).is512BitVector());
  if (!IsSupported)
    return SDValue();

  // First try with the select as operand 1.
  SDValue Folded =
      foldSelectWithIdentityConstant(N, DAG, /*ShouldCommuteOperands=*/false);
  if (Folded)
    return Folded;

  // For commutative opcodes the select may also sit in operand 0.
  if (DAG.getTargetLoweringInfo().isCommutativeBinOp(N->getOpcode()))
    Folded =
        foldSelectWithIdentityConstant(N, DAG, /*ShouldCommuteOperands=*/true);

  // Still empty here if neither attempt matched.
  return Folded;
}
48999+
4892349000
/// Do target-specific dag combines on floating-point adds/subs.
4892449001
static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
4892549002
const X86Subtarget &Subtarget) {
@@ -48929,6 +49006,9 @@ static SDValue combineFaddFsub(SDNode *N, SelectionDAG &DAG,
4892949006
if (SDValue COp = combineFaddCFmul(N, DAG, Subtarget))
4893049007
return COp;
4893149008

49009+
if (SDValue Sel = combineBinopWithSelect(N, DAG, Subtarget))
49010+
return Sel;
49011+
4893249012
return SDValue();
4893349013
}
4893449014

llvm/test/CodeGen/X86/avx512fp16-arith-intrinsics.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ define <32 x half> @test_int_x86_avx512fp16_maskz_sub_ph_512(<32 x half> %src, <
8383
; CHECK: # %bb.0:
8484
; CHECK-NEXT: kmovd %edi, %k1
8585
; CHECK-NEXT: vsubph %zmm2, %zmm1, %zmm0 {%k1} {z}
86-
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1 {%k1} {z}
87-
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0
86+
; CHECK-NEXT: vsubph (%rsi), %zmm1, %zmm1
87+
; CHECK-NEXT: vsubph %zmm1, %zmm0, %zmm0 {%k1}
8888
; CHECK-NEXT: retq
8989
%mask = bitcast i32 %msk to <32 x i1>
9090
%val = load <32 x half>, <32 x half>* %ptr

llvm/test/CodeGen/X86/vector-bo-select.ll

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,8 @@ define <4 x float> @fadd_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
2727
; AVX512VL: # %bb.0:
2828
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
2929
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
30-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} xmm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
31-
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1}
32-
; AVX512VL-NEXT: vaddps %xmm0, %xmm1, %xmm0
30+
; AVX512VL-NEXT: vaddps %xmm2, %xmm1, %xmm1 {%k1}
31+
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
3332
; AVX512VL-NEXT: retq
3433
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> <float -0.0, float -0.0, float -0.0, float -0.0>
3534
%r = fadd <4 x float> %x, %s
@@ -62,9 +61,8 @@ define <8 x float> @fadd_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x
6261
; AVX512VL-NEXT: vpmovsxwd %xmm0, %ymm0
6362
; AVX512VL-NEXT: vpslld $31, %ymm0, %ymm0
6463
; AVX512VL-NEXT: vptestmd %ymm0, %ymm0, %k1
65-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm0 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
66-
; AVX512VL-NEXT: vmovaps %ymm2, %ymm0 {%k1}
67-
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0
64+
; AVX512VL-NEXT: vaddps %ymm2, %ymm1, %ymm1 {%k1}
65+
; AVX512VL-NEXT: vmovaps %ymm1, %ymm0
6866
; AVX512VL-NEXT: retq
6967
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
7068
%r = fadd <8 x float> %s, %x
@@ -92,8 +90,8 @@ define <16 x float> @fadd_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
9290
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
9391
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
9492
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
95-
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
9693
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
94+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
9795
; AVX512-NEXT: retq
9896
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
9997
%r = fadd <16 x float> %x, %s
@@ -121,8 +119,8 @@ define <16 x float> @fadd_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef
121119
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
122120
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
123121
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
124-
; AVX512-NEXT: vbroadcastss {{\.?LCPI[0-9]+_[0-9]+}}(%rip), %zmm2 {%k1}
125-
; AVX512-NEXT: vaddps %zmm1, %zmm2, %zmm0
122+
; AVX512-NEXT: vaddps %zmm2, %zmm1, %zmm0
123+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
126124
; AVX512-NEXT: retq
127125
%s = select <16 x i1> %b, <16 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>, <16 x float> %y
128126
%r = fadd <16 x float> %s, %x
@@ -152,14 +150,16 @@ define <4 x float> @fsub_v4f32(<4 x i1> %b, <4 x float> noundef %x, <4 x float>
152150
; AVX512VL: # %bb.0:
153151
; AVX512VL-NEXT: vpslld $31, %xmm0, %xmm0
154152
; AVX512VL-NEXT: vptestmd %xmm0, %xmm0, %k1
155-
; AVX512VL-NEXT: vmovaps %xmm2, %xmm0 {%k1} {z}
156-
; AVX512VL-NEXT: vsubps %xmm0, %xmm1, %xmm0
153+
; AVX512VL-NEXT: vsubps %xmm2, %xmm1, %xmm1 {%k1}
154+
; AVX512VL-NEXT: vmovaps %xmm1, %xmm0
157155
; AVX512VL-NEXT: retq
158156
%s = select <4 x i1> %b, <4 x float> %y, <4 x float> zeroinitializer
159157
%r = fsub <4 x float> %x, %s
160158
ret <4 x float> %r
161159
}
162160

161+
; negative test - fsub is not commutative; there is no identity constant for operand 0
162+
163163
define <8 x float> @fsub_v8f32_commute(<8 x i1> %b, <8 x float> noundef %x, <8 x float> noundef %y) {
164164
; AVX2-LABEL: fsub_v8f32_commute:
165165
; AVX2: # %bb.0:
@@ -214,15 +214,17 @@ define <16 x float> @fsub_v16f32_swap(<16 x i1> %b, <16 x float> noundef %x, <16
214214
; AVX512: # %bb.0:
215215
; AVX512-NEXT: vpmovsxbd %xmm0, %zmm0
216216
; AVX512-NEXT: vpslld $31, %zmm0, %zmm0
217-
; AVX512-NEXT: vptestnmd %zmm0, %zmm0, %k1
218-
; AVX512-NEXT: vmovaps %zmm2, %zmm0 {%k1} {z}
219-
; AVX512-NEXT: vsubps %zmm0, %zmm1, %zmm0
217+
; AVX512-NEXT: vptestmd %zmm0, %zmm0, %k1
218+
; AVX512-NEXT: vsubps %zmm2, %zmm1, %zmm0
219+
; AVX512-NEXT: vmovaps %zmm1, %zmm0 {%k1}
220220
; AVX512-NEXT: retq
221221
%s = select <16 x i1> %b, <16 x float> zeroinitializer, <16 x float> %y
222222
%r = fsub <16 x float> %x, %s
223223
ret <16 x float> %r
224224
}
225225

226+
; negative test - fsub is not commutative; there is no identity constant for operand 0
227+
226228
define <16 x float> @fsub_v16f32_commute_swap(<16 x i1> %b, <16 x float> noundef %x, <16 x float> noundef %y) {
227229
; AVX2-LABEL: fsub_v16f32_commute_swap:
228230
; AVX2: # %bb.0:
@@ -570,9 +572,7 @@ define <8 x float> @fadd_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
570572
; AVX512VL-LABEL: fadd_v8f32_cast_cond:
571573
; AVX512VL: # %bb.0:
572574
; AVX512VL-NEXT: kmovw %edi, %k1
573-
; AVX512VL-NEXT: vbroadcastss {{.*#+}} ymm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
574-
; AVX512VL-NEXT: vmovaps %ymm1, %ymm2 {%k1}
575-
; AVX512VL-NEXT: vaddps %ymm2, %ymm0, %ymm0
575+
; AVX512VL-NEXT: vaddps %ymm1, %ymm0, %ymm0 {%k1}
576576
; AVX512VL-NEXT: retq
577577
%b = bitcast i8 %pb to <8 x i1>
578578
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> <float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0, float -0.0>
@@ -636,9 +636,7 @@ define <8 x double> @fadd_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
636636
; AVX512-LABEL: fadd_v8f64_cast_cond:
637637
; AVX512: # %bb.0:
638638
; AVX512-NEXT: kmovw %edi, %k1
639-
; AVX512-NEXT: vbroadcastsd {{.*#+}} zmm2 = [-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0,-0.0E+0]
640-
; AVX512-NEXT: vmovapd %zmm1, %zmm2 {%k1}
641-
; AVX512-NEXT: vaddpd %zmm2, %zmm0, %zmm0
639+
; AVX512-NEXT: vaddpd %zmm1, %zmm0, %zmm0 {%k1}
642640
; AVX512-NEXT: retq
643641
%b = bitcast i8 %pb to <8 x i1>
644642
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> <double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0, double -0.0>
@@ -709,8 +707,7 @@ define <8 x float> @fsub_v8f32_cast_cond(i8 noundef zeroext %pb, <8 x float> nou
709707
; AVX512VL-LABEL: fsub_v8f32_cast_cond:
710708
; AVX512VL: # %bb.0:
711709
; AVX512VL-NEXT: kmovw %edi, %k1
712-
; AVX512VL-NEXT: vmovaps %ymm1, %ymm1 {%k1} {z}
713-
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0
710+
; AVX512VL-NEXT: vsubps %ymm1, %ymm0, %ymm0 {%k1}
714711
; AVX512VL-NEXT: retq
715712
%b = bitcast i8 %pb to <8 x i1>
716713
%s = select <8 x i1> %b, <8 x float> %y, <8 x float> zeroinitializer
@@ -775,8 +772,7 @@ define <8 x double> @fsub_v8f64_cast_cond(i8 noundef zeroext %pb, <8 x double> n
775772
; AVX512-LABEL: fsub_v8f64_cast_cond:
776773
; AVX512: # %bb.0:
777774
; AVX512-NEXT: kmovw %edi, %k1
778-
; AVX512-NEXT: vmovapd %zmm1, %zmm1 {%k1} {z}
779-
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0
775+
; AVX512-NEXT: vsubpd %zmm1, %zmm0, %zmm0 {%k1}
780776
; AVX512-NEXT: retq
781777
%b = bitcast i8 %pb to <8 x i1>
782778
%s = select <8 x i1> %b, <8 x double> %y, <8 x double> zeroinitializer

0 commit comments

Comments
 (0)