@@ -71,7 +71,7 @@ void SPIRVCombinerHelper::applySPIRVDistance(MachineInstr &MI) const {
7171// / (vXf32 (g_intrinsic faceforward
7272// / (vXf32 N) (vXf32 I) (vXf32 Ng)))
7373// /
74- // / This only works for Vulkan targets.
74+ // / This only works for Vulkan shader targets.
7575// /
7676bool SPIRVCombinerHelper::matchSelectToFaceForward (MachineInstr &MI) const {
7777 if (!STI.isShader ())
@@ -88,8 +88,11 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
8888 CmpInst::Predicate Pred;
8989 if (!mi_match (CondReg, MRI,
9090 m_GFCmp (m_Pred (Pred), m_Reg (DotReg), m_Reg (CondZeroReg))) ||
91- Pred != CmpInst::FCMP_OLT)
92- return false ;
91+ !(Pred == CmpInst::FCMP_OLT || Pred == CmpInst::FCMP_ULT)) {
92+ if (!(Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT))
93+ return false ;
94+ std::swap (DotReg, CondZeroReg);
95+ }
9396
9497 // Check if FCMP is a comparison between a dot product and 0.
9598 MachineInstr *DotInstr = MRI.getVRegDef (DotReg);
@@ -109,29 +112,43 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
109112 return false ;
110113
111114 // Check if select's false operand is the negation of the true operand.
112- auto AreNegatedConstants = [&](Register TrueReg, Register FalseReg) {
113- const ConstantFP * TrueVal, * FalseVal;
114- if (!mi_match (TrueReg, MRI, m_GFCst (TrueVal)) ||
115- !mi_match (FalseReg, MRI, m_GFCst (FalseVal)))
115+ auto AreNegatedConstantsOrSplats = [&](Register TrueReg, Register FalseReg) {
116+ std::optional<FPValueAndVReg> TrueVal, FalseVal;
117+ if (!mi_match (TrueReg, MRI, m_GFCstOrSplat (TrueVal)) ||
118+ !mi_match (FalseReg, MRI, m_GFCstOrSplat (FalseVal)))
116119 return false ;
117- APFloat TrueValNegated = TrueVal->getValue () ;
120+ APFloat TrueValNegated = TrueVal->Value ;
118121 TrueValNegated.changeSign ();
119- return FalseVal->getValue () .compare (TrueValNegated) == APFloat::cmpEqual;
122+ return FalseVal->Value .compare (TrueValNegated) == APFloat::cmpEqual;
120123 };
121124
122- if (!mi_match (FalseReg , MRI, m_GFNeg (m_SpecificReg (TrueReg ))) &&
123- !mi_match (TrueReg , MRI, m_GFNeg (m_SpecificReg (FalseReg )))) {
124- // Check if they're constant opposites.
125+ if (!mi_match (TrueReg , MRI, m_GFNeg (m_SpecificReg (FalseReg ))) &&
126+ !mi_match (FalseReg , MRI, m_GFNeg (m_SpecificReg (TrueReg )))) {
127+ std::optional<FPValueAndVReg> MulConstant;
125128 MachineInstr *TrueInstr = MRI.getVRegDef (TrueReg);
126129 MachineInstr *FalseInstr = MRI.getVRegDef (FalseReg);
127130 if (TrueInstr->getOpcode () == TargetOpcode::G_BUILD_VECTOR &&
128131 FalseInstr->getOpcode () == TargetOpcode::G_BUILD_VECTOR &&
129132 TrueInstr->getNumOperands () == FalseInstr->getNumOperands ()) {
130133 for (unsigned I = 1 ; I < TrueInstr->getNumOperands (); ++I)
131- if (!AreNegatedConstants (TrueInstr->getOperand (I).getReg (),
132- FalseInstr->getOperand (I).getReg ()))
134+ if (!AreNegatedConstantsOrSplats (TrueInstr->getOperand (I).getReg (),
135+ FalseInstr->getOperand (I).getReg ()))
133136 return false ;
134- } else if (!AreNegatedConstants (TrueReg, FalseReg))
137+ } else if (mi_match (TrueReg, MRI,
138+ m_GFMul (m_SpecificReg (FalseReg),
139+ m_GFCstOrSplat (MulConstant))) ||
140+ mi_match (FalseReg, MRI,
141+ m_GFMul (m_SpecificReg (TrueReg),
142+ m_GFCstOrSplat (MulConstant))) ||
143+ mi_match (TrueReg, MRI,
144+ m_GFMul (m_GFCstOrSplat (MulConstant),
145+ m_SpecificReg (FalseReg))) ||
146+ mi_match (FalseReg, MRI,
147+ m_GFMul (m_GFCstOrSplat (MulConstant),
148+ m_SpecificReg (TrueReg)))) {
149+ if (!MulConstant || !MulConstant->Value .isExactlyValue (-1.0 ))
150+ return false ;
151+ } else if (!AreNegatedConstantsOrSplats (TrueReg, FalseReg))
135152 return false ;
136153 }
137154
@@ -140,17 +157,28 @@ bool SPIRVCombinerHelper::matchSelectToFaceForward(MachineInstr &MI) const {
140157
141158void SPIRVCombinerHelper::applySPIRVFaceForward (MachineInstr &MI) const {
142159 // Extract the operands for N, I, and Ng from the match criteria.
143- Register CondReg, TrueReg, DotReg, DotOperand1, DotOperand2 ;
144- if (! mi_match (MI. getOperand ( 0 ). getReg (), MRI,
145- m_GISelect ( m_Reg (CondReg), m_Reg (TrueReg), m_Reg ())))
146- return ;
147- if (! mi_match (CondReg, MRI, m_GFCmp ( m_Pred (), m_Reg (DotReg), m_Reg ())) )
148- return ;
160+ Register CondReg = MI. getOperand ( 1 ). getReg () ;
161+ MachineInstr *CondInstr = MRI. getVRegDef (CondReg);
162+ Register DotReg = CondInstr-> getOperand ( 2 ). getReg ();
163+ CmpInst::Predicate Pred = cast<GFCmp>(CondInstr)-> getCond () ;
164+ if (Pred == CmpInst::FCMP_OGT || Pred == CmpInst::FCMP_UGT )
165+ DotReg = CondInstr-> getOperand ( 3 ). getReg () ;
149166 MachineInstr *DotInstr = MRI.getVRegDef (DotReg);
150- if (!mi_match (DotReg, MRI, m_GFMul (m_Reg (DotOperand1), m_Reg (DotOperand2)))) {
167+ Register DotOperand1, DotOperand2;
168+ if (DotInstr->getOpcode () == TargetOpcode::G_FMUL) {
169+ DotOperand1 = DotInstr->getOperand (1 ).getReg ();
170+ DotOperand2 = DotInstr->getOperand (2 ).getReg ();
171+ } else {
151172 DotOperand1 = DotInstr->getOperand (2 ).getReg ();
152173 DotOperand2 = DotInstr->getOperand (3 ).getReg ();
153174 }
175+ Register TrueReg = MI.getOperand (2 ).getReg ();
176+ Register FalseReg = MI.getOperand (3 ).getReg ();
177+ MachineInstr *TrueInstr = MRI.getVRegDef (TrueReg);
178+ if (TrueInstr->getOpcode () == TargetOpcode::G_FNEG ||
179+ TrueInstr->getOpcode () == TargetOpcode::G_FMUL)
180+ std::swap (TrueReg, FalseReg);
181+ MachineInstr *FalseInstr = MRI.getVRegDef (FalseReg);
154182
155183 Register ResultReg = MI.getOperand (0 ).getReg ();
156184 Builder.setInstrAndDebugLoc (MI);
@@ -159,5 +187,25 @@ void SPIRVCombinerHelper::applySPIRVFaceForward(MachineInstr &MI) const {
159187 .addUse (DotOperand1) // I
160188 .addUse (DotOperand2); // Ng
161189
162- MI.eraseFromParent ();
190+ SPIRVGlobalRegistry *GR =
191+ MI.getMF ()->getSubtarget <SPIRVSubtarget>().getSPIRVGlobalRegistry ();
192+ auto RemoveAllUses = [&](Register Reg) {
193+ SmallVector<MachineInstr *, 4 > UsesToErase;
194+ for (auto &UseMI : MRI.use_instructions (Reg))
195+ UsesToErase.push_back (&UseMI);
196+
197+ // calling eraseFromParent to early invalidates the iterator.
198+ for (auto *MIToErase : UsesToErase)
199+ MIToErase->eraseFromParent ();
200+ };
201+
202+ RemoveAllUses (CondReg); // remove all uses of FCMP Result
203+ GR->invalidateMachineInstr (CondInstr);
204+ CondInstr->eraseFromParent (); // remove FCMP instruction
205+ RemoveAllUses (DotReg); // remove all uses of spv_fdot/G_FMUL Result
206+ GR->invalidateMachineInstr (DotInstr);
207+ DotInstr->eraseFromParent (); // remove spv_fdot/G_FMUL instruction
208+ RemoveAllUses (FalseReg);
209+ GR->invalidateMachineInstr (FalseInstr);
210+ FalseInstr->eraseFromParent ();
163211}
0 commit comments