@@ -4292,3 +4292,210 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
42924292
42934293 return true ;
42944294}
4295+
4296+ // / Helper function to recursively check if all uses of a register are valid
4297+ // / for the unaligned extract load combiner.
4298+ // / Automatically traverses through bitcasts to validate all usage patterns.
4299+ // / Valid terminal uses are: direct extracts or pad vector operations (with use
4300+ // / check).
4301+ static bool areLoadUsesValidForExtractCombine (Register Reg,
4302+ unsigned ZExtExtractOpcode,
4303+ unsigned SExtExtractOpcode,
4304+ unsigned PadVectorOpcode,
4305+ MachineRegisterInfo &MRI) {
4306+
4307+ auto IsValidExtractOpcode = [&](unsigned Opcode) {
4308+ return Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
4309+ Opcode == ZExtExtractOpcode || Opcode == SExtExtractOpcode;
4310+ };
4311+
4312+ for (const MachineInstr &Use : MRI.use_nodbg_instructions (Reg)) {
4313+ const unsigned UseOpcode = Use.getOpcode ();
4314+
4315+ if (UseOpcode == TargetOpcode::G_BITCAST) {
4316+ // Recursively check bitcast uses
4317+ const Register BitcastDst = Use.getOperand (0 ).getReg ();
4318+ if (!areLoadUsesValidForExtractCombine (BitcastDst, ZExtExtractOpcode,
4319+ SExtExtractOpcode, PadVectorOpcode,
4320+ MRI))
4321+ return false ;
4322+ continue ;
4323+ }
4324+
4325+ if (IsValidExtractOpcode (UseOpcode)) {
4326+ // Direct extract is valid (plain, zext, or sext)
4327+ continue ;
4328+ }
4329+
4330+ if (UseOpcode == PadVectorOpcode) {
4331+ // Pad is valid if only used by extracts
4332+ const Register PadDst = Use.getOperand (0 ).getReg ();
4333+ for (const MachineInstr &PadUse : MRI.use_nodbg_instructions (PadDst)) {
4334+ if (!IsValidExtractOpcode (PadUse.getOpcode ()))
4335+ return false ;
4336+ }
4337+ continue ;
4338+ }
4339+
4340+ // Invalid use
4341+ return false ;
4342+ }
4343+
4344+ return true ;
4345+ }
4346+
4347+ // / Match unaligned vector loads that are only used for extracting elements
4348+ // / and convert them to direct scalar loads.
4349+ // / Supports s8, s16 and s32 element extractions from various vector
4350+ // / configurations. Pattern:
4351+ // / %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < M/8)
4352+ // / %bitcast:_(<K x sX>) = G_BITCAST %vec
4353+ // / %idx:_(s32) = G_CONSTANT i32 N
4354+ // / %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
4355+ // / Or with G_AIE_PAD_VECTOR_UNDEF:
4356+ // / %vec = G_LOAD %ptr :: (unaligned)
4357+ // / %bitcast = G_BITCAST %vec
4358+ // / %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
4359+ // / %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
4360+ // / Converts to:
4361+ // / %offset:_(s20) = G_CONSTANT i20 (N * sizeof(sX))
4362+ // / %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
4363+ // / %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
4364+ // / %result:_(s32) = G_[Z/S]EXT %elt
4365+ bool llvm::matchUnalignedExtractLoad (MachineInstr &ExtractMI,
4366+ MachineRegisterInfo &MRI,
4367+ GISelChangeObserver &Observer,
4368+ BuildFnTy &MatchInfo) {
4369+ const MachineFunction &MF = *ExtractMI.getMF ();
4370+ const AIEBaseInstrInfo &TII =
4371+ *static_cast <const AIEBaseInstrInfo *>(MF.getSubtarget ().getInstrInfo ());
4372+
4373+ const unsigned Opcode = ExtractMI.getOpcode ();
4374+ const unsigned ZExtExtractOpcode =
4375+ TII.getGenericExtractVectorEltOpcode (false );
4376+ const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode (true );
4377+ const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode ();
4378+
4379+ const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
4380+ const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
4381+ const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
4382+
4383+ if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
4384+ return false ;
4385+
4386+ // Get the index operand
4387+ const Register IdxReg = ExtractMI.getOperand (2 ).getReg ();
4388+ const auto IdxCst = getIConstantVRegValWithLookThrough (IdxReg, MRI);
4389+ if (!IdxCst)
4390+ return false ;
4391+ const int64_t Index = IdxCst->Value .getSExtValue ();
4392+
4393+ // Get the vector operand
4394+ const Register VecReg = ExtractMI.getOperand (1 ).getReg ();
4395+ const LLT VecTy = MRI.getType (VecReg);
4396+
4397+ // Check if vector has extractable element types (s8, s16, or s32)
4398+ if (!VecTy.isVector ())
4399+ return false ;
4400+
4401+ const LLT ElemTy = VecTy.getElementType ();
4402+ const unsigned ElemSize = ElemTy.getSizeInBits ();
4403+ if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32 )
4404+ return false ;
4405+
4406+ // Trace back through G_AIE_PAD_VECTOR_UNDEF if present
4407+ MachineInstr *VecDefMI = MRI.getVRegDef (VecReg);
4408+ Register SourceVecReg = VecReg;
4409+
4410+ if (VecDefMI->getOpcode () == PadVectorOpcode) {
4411+ SourceVecReg = VecDefMI->getOperand (1 ).getReg ();
4412+ VecDefMI = MRI.getVRegDef (SourceVecReg);
4413+ }
4414+
4415+ // Check for G_BITCAST (or direct vector if no bitcast needed)
4416+ Register LoadVecReg = SourceVecReg;
4417+ if (VecDefMI->getOpcode () == TargetOpcode::G_BITCAST)
4418+ LoadVecReg = VecDefMI->getOperand (1 ).getReg ();
4419+
4420+ MachineInstr *LoadMI = MRI.getVRegDef (LoadVecReg);
4421+
4422+ // Check if it's a load
4423+ if (LoadMI->getOpcode () != TargetOpcode::G_LOAD)
4424+ return false ;
4425+
4426+ // Check if the load is unaligned relative to the vector's total size
4427+ if (LoadMI->memoperands_empty ())
4428+ return false ;
4429+
4430+ const MachineMemOperand *MMO = LoadMI->memoperands ().front ();
4431+ const LLT LoadVecTy = MRI.getType (LoadVecReg);
4432+ const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes ();
4433+ // Vector is unaligned if alignment < vector size
4434+ // This allows extracting elements when the vector load itself is unaligned
4435+ if (MMO->getAlign ().value () >= LoadVecSizeInBytes)
4436+ return false ;
4437+
4438+ // Check that the loaded vector is only used by extracts (through bitcast and
4439+ // pad). The helper function will automatically traverse through bitcasts.
4440+ const Register LoadDstReg = LoadMI->getOperand (0 ).getReg ();
4441+
4442+ if (!areLoadUsesValidForExtractCombine (LoadDstReg, ZExtExtractOpcode,
4443+ SExtExtractOpcode, PadVectorOpcode,
4444+ MRI))
4445+ return false ;
4446+
4447+ // All checks passed, we can combine
4448+ MatchInfo = [=, &ExtractMI, &MRI, &Observer](MachineIRBuilder &B) {
4449+ const Register PtrReg = LoadMI->getOperand (1 ).getReg ();
4450+ const LLT S20 = LLT::scalar (20 );
4451+
4452+ const unsigned ElemSizeInBytes = ElemSize / 8 ;
4453+ const int64_t ByteOffset = Index * ElemSizeInBytes;
4454+
4455+ // Set insertion point right after the original vector load
4456+ B.setInsertPt (*LoadMI->getParent (), std::next (LoadMI->getIterator ()));
4457+ B.setDebugLoc (LoadMI->getDebugLoc ());
4458+
4459+ // Create offset constant and pointer add
4460+ const Register OffsetReg = B.buildConstant (S20, ByteOffset).getReg (0 );
4461+ const Register NewPtrReg =
4462+ B.buildPtrAdd (MRI.getType (PtrReg), PtrReg, OffsetReg).getReg (0 );
4463+
4464+ // Calculate alignment for scalar load based on original vector load
4465+ // alignment using GCD to find the maximum provable alignment
4466+ const unsigned OrigAlign = MMO->getAlign ().value ();
4467+ const unsigned ScalarAlign = std::gcd (OrigAlign, OrigAlign + ByteOffset);
4468+
4469+ // Create new scalar load with derived alignment
4470+ MachineFunction &MF = B.getMF ();
4471+ MachineMemOperand *NewMMO =
4472+ MF.getMachineMemOperand (MMO->getPointerInfo (), MMO->getFlags (),
4473+ ElemSizeInBytes, Align (ScalarAlign));
4474+
4475+ const Register LoadResultReg = MRI.createGenericVirtualRegister (ElemTy);
4476+ B.buildLoad (LoadResultReg, NewPtrReg, *NewMMO);
4477+
4478+ // Now set insertion point at the extract position for the copy/extension
4479+ B.setInstr (ExtractMI);
4480+
4481+ // Handle the result based on the original opcode
4482+ if (IsZExtExtract || IsSExtExtract) {
4483+ // Need to extend to s32
4484+ const Register DstReg = ExtractMI.getOperand (0 ).getReg ();
4485+ if (IsZExtExtract)
4486+ B.buildZExt (DstReg, LoadResultReg);
4487+ else
4488+ B.buildSExt (DstReg, LoadResultReg);
4489+ } else {
4490+ // G_EXTRACT_VECTOR_ELT
4491+ const Register DstReg = ExtractMI.getOperand (0 ).getReg ();
4492+ // Just copy the result
4493+ B.buildCopy (DstReg, LoadResultReg);
4494+ }
4495+
4496+ Observer.erasingInstr (ExtractMI);
4497+ ExtractMI.eraseFromParent ();
4498+ };
4499+
4500+ return true ;
4501+ }
0 commit comments