Skip to content

Commit 824225d

Browse files
committed
address farzon
1 parent f1aac85 commit 824225d

File tree

3 files changed

+182
-27
lines changed

3 files changed

+182
-27
lines changed

llvm/lib/Target/SPIRV/SPIRVCombinerHelper.cpp

Lines changed: 71 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
7171
/// (vXf32 (g_intrinsic faceforward
7272
/// (vXf32 N) (vXf32 I) (vXf32 Ng)))
7373
///
74-
/// This only works for Vulkan targets.
74+
/// This only works for Vulkan shader targets.
7575
///
7676
bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
7777
if (!STI.isShader())
@@ -88,8 +88,11 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
8888
CmpInst::Predicate Pred;
8989
if (!mi_match(CondReg, MRI,
9090
m_GFCmp(m_Pred(Pred), m_Reg(DotReg), m_Reg(CondZeroReg))) ||
91-
Pred != CmpInst::FCMP_OLT)
92-
return false;
91+
!(Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULT)) {
92+
if (!(Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT))
93+
return false;
94+
std::swap(DotReg, CondZeroReg);
95+
}
9396

9497
// Check if FCMP is a comparison between a dot product and 0.
9598
MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
@@ -109,29 +112,43 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
109112
return false;
110113

111114
// Check if select's false operand is the negation of the true operand.
112-
auto AreNegatedConstants = [&](Register TrueReg, Register FalseReg) {
113-
const ConstantFP *TrueVal, *FalseVal;
114-
if (!mi_match(TrueReg, MRI, m_GFCst(TrueVal)) ||
115-
!mi_match(FalseReg, MRI, m_GFCst(FalseVal)))
115+
auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
116+
std::optional<FPValueAndVReg> TrueVal, FalseVal;
117+
if (!mi_match(TrueReg, MRI, m_GFCstOrSplat(TrueVal)) ||
118+
!mi_match(FalseReg, MRI, m_GFCstOrSplat(FalseVal)))
116119
return false;
117-
APFloat TrueValNegated = TrueVal->getValue();
120+
APFloat TrueValNegated = TrueVal->Value;
118121
TrueValNegated.changeSign();
119-
return FalseVal->getValue().compare(TrueValNegated) == APFloat::cmpEqual;
122+
return FalseVal->Value.compare(TrueValNegated) == APFloat::cmpEqual;
120123
};
121124

122-
if (!mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg))) &&
123-
!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg)))) {
124-
// Check if they're constant opposites.
125+
if (!mi_match(TrueReg, MRI, m_GFNeg(m_SpecificReg(FalseReg))) &&
126+
!mi_match(FalseReg, MRI, m_GFNeg(m_SpecificReg(TrueReg)))) {
127+
std::optional<FPValueAndVReg> MulConstant;
125128
MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
126129
MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
127130
if (TrueInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
128131
FalseInstr->getOpcode() == TargetOpcode::G_BUILD_VECTOR &&
129132
TrueInstr->getNumOperands() == FalseInstr->getNumOperands()) {
130133
for (unsigned I = 1; I < TrueInstr->getNumOperands(); ++I)
131-
if (!AreNegatedConstants(TrueInstr->getOperand(I).getReg(),
132-
FalseInstr->getOperand(I).getReg()))
134+
if (!AreNegatedConstantsOrSplats(TrueInstr->getOperand(I).getReg(),
135+
FalseInstr->getOperand(I).getReg()))
133136
return false;
134-
} else if (!AreNegatedConstants(TrueReg, FalseReg))
137+
} else if (mi_match(TrueReg, MRI,
138+
m_GFMul(m_SpecificReg(FalseReg),
139+
m_GFCstOrSplat(MulConstant))) ||
140+
mi_match(FalseReg, MRI,
141+
m_GFMul(m_SpecificReg(TrueReg),
142+
m_GFCstOrSplat(MulConstant))) ||
143+
mi_match(TrueReg, MRI,
144+
m_GFMul(m_GFCstOrSplat(MulConstant),
145+
m_SpecificReg(FalseReg))) ||
146+
mi_match(FalseReg, MRI,
147+
m_GFMul(m_GFCstOrSplat(MulConstant),
148+
m_SpecificReg(TrueReg)))) {
149+
if (!MulConstant || !MulConstant->Value.isExactlyValue(-1.0))
150+
return false;
151+
} else if (!AreNegatedConstantsOrSplats(TrueReg, FalseReg))
135152
return false;
136153
}
137154

@@ -140,17 +157,28 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
140157

141158
void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
142159
// Extract the operands for N, I, and Ng from the match criteria.
143-
Register CondReg, TrueReg, DotReg, DotOperand1, DotOperand2;
144-
if (!mi_match(MI.getOperand(0).getReg(), MRI,
145-
m_GISelect(m_Reg(CondReg), m_Reg(TrueReg), m_Reg())))
146-
return;
147-
if (!mi_match(CondReg, MRI, m_GFCmp(m_Pred(), m_Reg(DotReg), m_Reg())))
148-
return;
160+
Register CondReg = MI.getOperand(1).getReg();
161+
MachineInstr *CondInstr = MRI.getVRegDef(CondReg);
162+
Register DotReg = CondInstr->getOperand(2).getReg();
163+
CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)->getCond();
164+
if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT)
165+
DotReg = CondInstr->getOperand(3).getReg();
149166
MachineInstr *DotInstr = MRI.getVRegDef(DotReg);
150-
if (!mi_match(DotReg, MRI, m_GFMul(m_Reg(DotOperand1), m_Reg(DotOperand2)))) {
167+
Register DotOperand1, DotOperand2;
168+
if (DotInstr->getOpcode() == TargetOpcode::G_FMUL) {
169+
DotOperand1 = DotInstr->getOperand(1).getReg();
170+
DotOperand2 = DotInstr->getOperand(2).getReg();
171+
} else {
151172
DotOperand1 = DotInstr->getOperand(2).getReg();
152173
DotOperand2 = DotInstr->getOperand(3).getReg();
153174
}
175+
Register TrueReg = MI.getOperand(2).getReg();
176+
Register FalseReg = MI.getOperand(3).getReg();
177+
MachineInstr *TrueInstr = MRI.getVRegDef(TrueReg);
178+
if (TrueInstr->getOpcode() == TargetOpcode::G_FNEG ||
179+
TrueInstr->getOpcode() == TargetOpcode::G_FMUL)
180+
std::swap(TrueReg, FalseReg);
181+
MachineInstr *FalseInstr = MRI.getVRegDef(FalseReg);
154182

155183
Register ResultReg = MI.getOperand(0).getReg();
156184
Builder.setInstrAndDebugLoc(MI);
@@ -159,5 +187,25 @@ void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
159187
.addUse(DotOperand1) // I
160188
.addUse(DotOperand2); // Ng
161189

162-
MI.eraseFromParent();
190+
SPIRVGlobalRegistry *GR =
191+
MI.getMF()->getSubtarget<SPIRVSubtarget>().getSPIRVGlobalRegistry();
192+
auto RemoveAllUses = [&](Register Reg) {
193+
SmallVector<MachineInstr *, 4> UsesToErase;
194+
for (auto &UseMI : MRI.use_instructions(Reg))
195+
UsesToErase.push_back(&UseMI);
196+
197+
// calling eraseFromParent to early invalidates the iterator.
198+
for (auto *MIToErase : UsesToErase)
199+
MIToErase->eraseFromParent();
200+
};
201+
202+
RemoveAllUses(CondReg); // remove all uses of FCMP Result
203+
GR->invalidateMachineInstr(CondInstr);
204+
CondInstr->eraseFromParent(); // remove FCMP instruction
205+
RemoveAllUses(DotReg); // remove all uses of spv_fdot/G_FMUL Result
206+
GR->invalidateMachineInstr(DotInstr);
207+
DotInstr->eraseFromParent(); // remove spv_fdot/G_FMUL instruction
208+
RemoveAllUses(FalseReg);
209+
GR->invalidateMachineInstr(FalseInstr);
210+
FalseInstr->eraseFromParent();
163211
}

llvm/test/CodeGen/SPIRV/GlobalISel/InstCombine/prelegalizercombiner-select-to-faceforward.mir

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ tracksRegLiveness: true
9999
legalized: true
100100
body: |
101101
bb.1.entry:
102-
; CHECK-LABEL: name: faceforward_instcombine_float4
102+
; CHECK-LABEL: name: faceforward_instcombine_float4_constants
103103
; CHECK-NOT: %10:_(s32) = G_FCONSTANT float 0.000000e+00
104104
; CHECK-NOT: %16:_(s32) = G_FCONSTANT float -1.000000e+00
105105
; CHECK-NOT: %15:_(<4 x s32>) = G_BUILD_VECTOR %16:_(s32), %16:_(s32), %16:_(s32), %16:_(s32)
@@ -130,4 +130,70 @@ body: |
130130
%11:_(s1) = G_FCMP floatpred(olt), %9:_(s32), %10:_
131131
%12:id(<4 x s32>) = G_SELECT %11:_(s1), %13:_, %15:_
132132
OpReturnValue %12:id(<4 x s32>)
133+
---
134+
name: faceforward_instcombine_float4_false_fmul
135+
tracksRegLiveness: true
136+
legalized: true
137+
body: |
138+
bb.1.entry:
139+
; CHECK-LABEL: name: faceforward_instcombine_float4_false_fmul
140+
; CHECK-NOT: %10:_(s32) = G_FCONSTANT float 0.000000e+00
141+
; CHECK-NOT: %13:_(s32) = G_FCONSTANT float -1.000000e+00
142+
; CHECK-NOT: %12:_(<4 x s32>) = G_BUILD_VECTOR %13:_(s32), %13:_(s32), %13:_(s32), %13:_(s32)
143+
; CHECK-NOT: %9:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.fdot), %1:vfid(<4 x s32>), %2:vfid(<4 x s32>)
144+
; CHECK-NOT: %11:_(s1) = G_FCMP floatpred(olt), %9:_(s32), %10:_
145+
; CHECK-NOT: %14:_(<4 x s32>) = G_FMUL %0:vfid, %12:_
146+
; CHECK-NOT: %15:id(<4 x s32>) = G_SELECT %11:_(s1), %0:vfid, %14:_
147+
; CHECK: %13:id(<4 x s32>) = G_INTRINSIC intrinsic(@llvm.spv.faceforward), %3(<4 x s32>), %4(<4 x s32>), %5(<4 x s32>)
148+
; CHECK: OpReturnValue %13(<4 x s32>)
149+
%4:type(s64) = OpTypeVector %3:type(s64), 4
150+
%6:type(s64) = OpTypeFunction %4:type(s64), %4:type(s64), %4:type(s64), %4:type(s64)
151+
%3:type(s64) = OpTypeFloat 32
152+
OpName %0:vfid(<4 x s32>), 97
153+
OpName %1:vfid(<4 x s32>), 98
154+
OpName %2:vfid(<4 x s32>), 99
155+
%5:iid(s64) = OpFunction %4:type(s64), 0, %6:type(s64)
156+
%0:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
157+
%1:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
158+
%2:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
159+
OpName %5:iid(s64), 1701011814, 2003988326, 1600418401, 1953721961, 1651339107, 1600482921, 1634692198, 1717515380, 1702063201, 1970103903, 108
160+
%10:_(s32) = G_FCONSTANT float 0.000000e+00
161+
%13:_(s32) = G_FCONSTANT float -1.000000e+00
162+
%12:_(<4 x s32>) = G_BUILD_VECTOR %13:_(s32), %13:_(s32), %13:_(s32), %13:_(s32)
163+
%9:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.fdot), %1:vfid(<4 x s32>), %2:vfid(<4 x s32>)
164+
%11:_(s1) = G_FCMP floatpred(olt), %9:_(s32), %10:_
165+
%14:_(<4 x s32>) = G_FMUL %0:vfid, %12:_
166+
%15:id(<4 x s32>) = G_SELECT %11:_(s1), %0:vfid, %14:_
167+
OpReturnValue %15:id(<4 x s32>)
168+
---
169+
name: faceforward_instcombine_float4_ogt
170+
tracksRegLiveness: true
171+
legalized: true
172+
body: |
173+
bb.1.entry:
174+
; CHECK-LABEL: name: faceforward_instcombine_float4
175+
; CHECK-NOT: %10:_(s32) = G_FCONSTANT float 0.000000e+00
176+
; CHECK-NOT: %9:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.fdot), %1:vfid(<4 x s32>), %2:vfid(<4 x s32>)
177+
; CHECK-NOT: %11:_(s1) = G_FCMP floatpred(ogt), %10:_(s32), %9:_
178+
; CHECK-NOT: %12:_(<4 x s32>) = G_FNEG %0:vfid
179+
; CHECK-NOT: %13:id(<4 x s32>) = G_SELECT %11:_(s1), %12:_, %0:vfid
180+
; CHECK: %11:id(<4 x s32>) = G_INTRINSIC intrinsic(@llvm.spv.faceforward), %3(<4 x s32>), %4(<4 x s32>), %5(<4 x s32>)
181+
; CHECK: OpReturnValue %11(<4 x s32>)
182+
%4:type(s64) = OpTypeVector %3:type(s64), 4
183+
%6:type(s64) = OpTypeFunction %4:type(s64), %4:type(s64), %4:type(s64), %4:type(s64)
184+
%3:type(s64) = OpTypeFloat 32
185+
OpName %0:vfid(<4 x s32>), 97
186+
OpName %1:vfid(<4 x s32>), 98
187+
OpName %2:vfid(<4 x s32>), 99
188+
%5:iid(s64) = OpFunction %4:type(s64), 0, %6:type(s64)
189+
%0:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
190+
%1:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
191+
%2:vfid(<4 x s32>) = OpFunctionParameter %4:type(s64)
192+
OpName %5:iid(s64), 1701011814, 2003988326, 1600418401, 1953721961, 1651339107, 1600482921, 1634692198, 1868510324, 29799
193+
%10:_(s32) = G_FCONSTANT float 0.000000e+00
194+
%9:_(s32) = G_INTRINSIC intrinsic(@llvm.spv.fdot), %1:vfid(<4 x s32>), %2:vfid(<4 x s32>)
195+
%11:_(s1) = G_FCMP floatpred(ogt), %10:_(s32), %9:_
196+
%12:_(<4 x s32>) = G_FNEG %0:vfid
197+
%13:id(<4 x s32>) = G_SELECT %11:_(s1), %12:_, %0:vfid
198+
OpReturnValue %13:id(<4 x s32>)
133199

llvm/test/CodeGen/SPIRV/hlsl-intrinsics/faceforward.ll

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ entry:
7777
%spv.fdot = call float @llvm.spv.fdot.v4f32(<4 x float> %b, <4 x float> %c)
7878
%fcmp = fcmp olt float %spv.fdot, 0.000000e+00
7979
%fneg = fneg <4 x float> %a
80-
%select = select i1 %fcmp, <4 x float> %a, <4 x float> %fneg
80+
%select = select i1 %fcmp, <4 x float> %fneg, <4 x float> %a
8181
ret <4 x float> %select
8282
}
8383

@@ -89,7 +89,7 @@ entry:
8989
; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#float_32]]
9090
; CHECK: %[[#]] = OpExtInst %[[#float_32]] %[[#op_ext_glsl]] FaceForward %[[#]] %[[#arg1]] %[[#arg2]]
9191
%fmul = fmul float %b, %c
92-
%fcmp = fcmp olt float %fmul, 0.000000e+00
92+
%fcmp = fcmp olt float %fmul, -0.000000e+00
9393
%select = select i1 %fcmp, float 1.000000e+00, float -1.000000e+00
9494
ret float %select
9595
}
@@ -107,7 +107,48 @@ entry:
107107
ret <4 x float> %select
108108
}
109109

110-
; The other fucntions are the test, but a entry point is required to have a valid SPIR-V module.
110+
define internal noundef <4 x float> @faceforward_instcombine_float4_splat(<4 x float> noundef %a, <4 x float> noundef %b, <4 x float> noundef %c) {
111+
entry:
112+
; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
113+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
114+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
115+
; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#vec4_float_32]]
116+
; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] FaceForward %[[#]] %[[#arg1]] %[[#arg2]]
117+
%spv.fdot = call float @llvm.spv.fdot.v4f32(<4 x float> %b, <4 x float> %c)
118+
%fcmp = fcmp olt float %spv.fdot, 0.000000e+00
119+
%select = select i1 %fcmp, <4 x float> splat (float 2.500000e+00), <4 x float> splat (float -2.500000e+00)
120+
ret <4 x float> %select
121+
}
122+
123+
define internal noundef <4 x float> @faceforward_instcombine_float4_false_fmul(<4 x float> noundef %a, <4 x float> noundef %b, <4 x float> noundef %c) {
124+
entry:
125+
; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
126+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
127+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
128+
; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#vec4_float_32]]
129+
; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] FaceForward %[[#]] %[[#arg1]] %[[#arg2]]
130+
%spv.fdot = call float @llvm.spv.fdot.v4f32(<4 x float> %b, <4 x float> %c)
131+
%fcmp = fcmp olt float %spv.fdot, 0.000000e+00
132+
%fneg = fmul <4 x float> %a, <float -1.000000e+00, float -1.000000e+00, float -1.000000e+00, float -1.000000e+00>
133+
%select = select i1 %fcmp, <4 x float> %a, <4 x float> %fneg
134+
ret <4 x float> %select
135+
}
136+
137+
define internal noundef <4 x float> @faceforward_instcombine_float4_ogt(<4 x float> noundef %a, <4 x float> noundef %b, <4 x float> noundef %c) {
138+
entry:
139+
; CHECK: %[[#]] = OpFunction %[[#vec4_float_32]] None %[[#]]
140+
; CHECK: %[[#arg0:]] = OpFunctionParameter %[[#vec4_float_32]]
141+
; CHECK: %[[#arg1:]] = OpFunctionParameter %[[#vec4_float_32]]
142+
; CHECK: %[[#arg2:]] = OpFunctionParameter %[[#vec4_float_32]]
143+
; CHECK: %[[#]] = OpExtInst %[[#vec4_float_32]] %[[#op_ext_glsl]] FaceForward %[[#]] %[[#arg1]] %[[#arg2]]
144+
%spv.fdot = call float @llvm.spv.fdot.v4f32(<4 x float> %b, <4 x float> %c)
145+
%fcmp = fcmp ogt float 0.000000e+00, %spv.fdot
146+
%fneg = fneg <4 x float> %a
147+
%select = select i1 %fcmp, <4 x float> %fneg, <4 x float> %a
148+
ret <4 x float> %select
149+
}
150+
151+
; The other functions are the test, but a entry point is required to have a valid SPIR-V module.
111152
define void @main() #1 {
112153
entry:
113154
ret void

0 commit comments

Comments
 (0)