@@ -4336,7 +4336,7 @@ enum class SrcStatus {
43364336 HALF_START = IS_UPPER_HALF,
43374337 HALF_END = IS_LOWER_HALF_NEG
43384338};
4339- // Test if the MI is truncating to half, such as `%reg0:n = G_TRUNC %reg1:2n`
4339+ /// Test if the MI is truncating to half, such as `%reg0:n = G_TRUNC %reg1:2n`
43404340static bool isTruncHalf (const MachineInstr *MI,
43414341 const MachineRegisterInfo &MRI) {
43424342 if (MI->getOpcode () != AMDGPU::G_TRUNC)
@@ -4347,8 +4347,8 @@ static bool isTruncHalf(const MachineInstr *MI,
43474347 return DstSize * 2 == SrcSize;
43484348}
43494349
4350- // Test if the MI is logic shift right with half bits,
4351- // such as `%reg0:2n =G_LSHR %reg1:2n, CONST(n)`
4350+ /// Test if the MI is a logical shift right by half the bit width,
4351+ /// such as `%reg0:2n = G_LSHR %reg1:2n, CONST(n)`
43524352static bool isLshrHalf (const MachineInstr *MI, const MachineRegisterInfo &MRI) {
43534353 if (MI->getOpcode () != AMDGPU::G_LSHR)
43544354 return false ;
@@ -4364,8 +4364,8 @@ static bool isLshrHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
43644364 return false ;
43654365}
43664366
4367- // Test if the MI is shift left with half bits,
4368- // such as `%reg0:2n =G_SHL %reg1:2n, CONST(n)`
4367+ /// Test if the MI is a shift left by half the bit width,
4368+ /// such as `%reg0:2n = G_SHL %reg1:2n, CONST(n)`
43694369static bool isShlHalf (const MachineInstr *MI, const MachineRegisterInfo &MRI) {
43704370 if (MI->getOpcode () != AMDGPU::G_SHL)
43714371 return false ;
@@ -4381,7 +4381,7 @@ static bool isShlHalf(const MachineInstr *MI, const MachineRegisterInfo &MRI) {
43814381 return false ;
43824382}
43834383
4384- // Test function, if the MI is `%reg0:n, %reg1:n = G_UNMERGE_VALUES %reg2:2n`
4384+ /// Test if the MI is `%reg0:n, %reg1:n = G_UNMERGE_VALUES %reg2:2n`
43854385static bool isUnmergeHalf (const MachineInstr *MI,
43864386 const MachineRegisterInfo &MRI) {
43874387 if (MI->getOpcode () != AMDGPU::G_UNMERGE_VALUES)
@@ -4566,6 +4566,8 @@ calcNextStatus(std::pair<Register, SrcStatus> Curr,
45664566 // Handle general Opc cases.
45674567 switch (Opc) {
45684568 case AMDGPU::G_BITCAST:
4569+ return std::optional<std::pair<Register, SrcStatus>>(
4570+ {MI->getOperand (1 ).getReg (), Curr.second });
45694571 case AMDGPU::COPY:
45704572 if (MI->getOperand (1 ).getReg ().isPhysical ())
45714573 return std::nullopt ;
@@ -4641,14 +4643,19 @@ calcNextStatus(std::pair<Register, SrcStatus> Curr,
46414643 return std::nullopt ;
46424644}
46434645
4644- class searchOptions {
4646+ /// This is used to control which statuses the current MI supports. For
4647+ /// example, a non-floating-point intrinsic such as @llvm.amdgcn.sdot2 does
4648+ /// not support the NEG bit on VOP3P.
4649+ /// The class can be further extended to recognize support for the SEL, NEG,
4650+ /// and ABS bits for different MIs on different architectures.
4651+ class SearchOptions {
46454652private:
46464653 bool HasNeg = false ;
4647- // Assume all complex pattern of VOP3P has opsel.
4654+ // Assume all complex patterns of VOP3P have opsel.
46484655 bool HasOpsel = true ;
46494656
46504657public:
4651- searchOptions (Register Reg, const MachineRegisterInfo &MRI) {
4658+ SearchOptions (Register Reg, const MachineRegisterInfo &MRI) {
46524659 const MachineInstr *MI = MRI.getVRegDef (Reg);
46534660 unsigned Opc = MI->getOpcode ();
46544661
@@ -4676,15 +4683,15 @@ class searchOptions {
46764683};
46774684
46784685static SmallVector<std::pair<Register, SrcStatus>>
4679- getSrcStats (Register Reg, const MachineRegisterInfo &MRI,
4680- searchOptions SearchOptions, int MaxDepth = 6 ) {
4686+ getSrcStats (Register Reg, const MachineRegisterInfo &MRI, SearchOptions SO,
4687+ int MaxDepth = 3 ) {
46814688 int Depth = 0 ;
46824689 auto Curr = calcNextStatus ({Reg, SrcStatus::IS_SAME}, MRI);
46834690 SmallVector<std::pair<Register, SrcStatus>> Statlist;
46844691
46854692 while (Depth <= MaxDepth && Curr.has_value ()) {
46864693 Depth++;
4687- if (SearchOptions .checkOptions (Curr.value ().second ))
4694+ if (SO .checkOptions (Curr.value ().second ))
46884695 Statlist.push_back (Curr.value ());
46894696 Curr = calcNextStatus (Curr.value (), MRI);
46904697 }
@@ -4693,19 +4700,18 @@ getSrcStats(Register Reg, const MachineRegisterInfo &MRI,
46934700}
46944701
46954702static std::pair<Register, SrcStatus>
4696- getLastSameOrNeg (Register Reg, const MachineRegisterInfo &MRI,
4697- searchOptions SearchOptions, int MaxDepth = 6 ) {
4703+ getLastSameOrNeg (Register Reg, const MachineRegisterInfo &MRI, SearchOptions SO,
4704+ int MaxDepth = 3 ) {
46984705 int Depth = 0 ;
46994706 std::pair<Register, SrcStatus> LastSameOrNeg = {Reg, SrcStatus::IS_SAME};
47004707 auto Curr = calcNextStatus (LastSameOrNeg, MRI);
47014708
47024709 while (Depth <= MaxDepth && Curr.has_value ()) {
47034710 Depth++;
4704- if (SearchOptions.checkOptions (Curr.value ().second )) {
4705- if (Curr.value ().second == SrcStatus::IS_SAME ||
4706- Curr.value ().second == SrcStatus::IS_HI_NEG ||
4707- Curr.value ().second == SrcStatus::IS_LO_NEG ||
4708- Curr.value ().second == SrcStatus::IS_BOTH_NEG)
4711+ SrcStatus Stat = Curr.value ().second ;
4712+ if (SO.checkOptions (Stat)) {
4713+ if (Stat == SrcStatus::IS_SAME || Stat == SrcStatus::IS_HI_NEG ||
4714+ Stat == SrcStatus::IS_LO_NEG || Stat == SrcStatus::IS_BOTH_NEG)
47094715 LastSameOrNeg = Curr.value ();
47104716 }
47114717 Curr = calcNextStatus (Curr.value (), MRI);
@@ -4766,10 +4772,9 @@ std::pair<Register, unsigned> AMDGPUInstructionSelector::selectVOP3PModsImpl(
47664772 return {RootReg, Mods};
47674773 }
47684774
4769- searchOptions SearchOptions (RootReg, MRI);
4775+ SearchOptions SO (RootReg, MRI);
47704776
4771- std::pair<Register, SrcStatus> Stat =
4772- getLastSameOrNeg (RootReg, MRI, SearchOptions);
4777+ std::pair<Register, SrcStatus> Stat = getLastSameOrNeg (RootReg, MRI, SO);
47734778
47744779 if (Stat.second == SrcStatus::IS_BOTH_NEG)
47754780 Mods ^= (SISrcMods::NEG | SISrcMods::NEG_HI);
@@ -4787,15 +4792,15 @@ std::pair<Register, unsigned> AMDGPUInstructionSelector::selectVOP3PModsImpl(
47874792 }
47884793
47894794 SmallVector<std::pair<Register, SrcStatus>> StatlistHi =
4790- getSrcStats (MI->getOperand (2 ).getReg (), MRI, SearchOptions );
4795+ getSrcStats (MI->getOperand (2 ).getReg (), MRI, SO );
47914796
47924797 if (StatlistHi.empty ()) {
47934798 Mods |= SISrcMods::OP_SEL_1;
47944799 return {Stat.first , Mods};
47954800 }
47964801
47974802 SmallVector<std::pair<Register, SrcStatus>> StatlistLo =
4798- getSrcStats (MI->getOperand (1 ).getReg (), MRI, SearchOptions );
4803+ getSrcStats (MI->getOperand (1 ).getReg (), MRI, SO );
47994804
48004805 if (StatlistLo.empty ()) {
48014806 Mods |= SISrcMods::OP_SEL_1;
@@ -4869,7 +4874,7 @@ static Register getLegalRegBank(Register NewReg, Register RootReg,
48694874 BuildMI (*BB, MI, MI->getDebugLoc (), TII.get (AMDGPU::COPY), DstReg)
48704875 .addReg (NewReg);
48714876
4872- // only accept VGPR.
4877+ // Only accept VGPR.
48734878 return MIB->getOperand (0 ).getReg ();
48744879}
48754880
0 commit comments