@@ -4291,3 +4291,233 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
 
   return true;
 }
+
+/// Helper function to recursively check if all uses of a register are valid
+/// for the unaligned extract load combiner.
+/// Automatically traverses through bitcasts to validate all usage patterns.
+/// Valid terminal uses are direct extracts, or pad vector operations whose
+/// uses are themselves all extracts.
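+/// For example, a chain such as G_LOAD -> G_BITCAST -> G_AIE_PAD_VECTOR_UNDEF
+/// -> G_AIE_ZEXT_EXTRACT_VECTOR_ELT is accepted, whereas any other user of the
+/// loaded vector (e.g. a G_STORE) rejects the combine.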
+static bool areLoadUsesValidForExtractCombine(Register Reg,
+                                              unsigned ZExtExtractOpcode,
+                                              unsigned SExtExtractOpcode,
+                                              unsigned PadVectorOpcode,
+                                              MachineRegisterInfo &MRI) {
+
+  auto IsValidExtractOpcode = [&](unsigned Opcode) {
+    return Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+           Opcode == ZExtExtractOpcode || Opcode == SExtExtractOpcode;
+  };
+
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    const unsigned UseOpcode = Use.getOpcode();
+
+    if (UseOpcode == TargetOpcode::G_BITCAST) {
+      // Recursively check bitcast uses
+      const Register BitcastDst = Use.getOperand(0).getReg();
+      if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
+                                             SExtExtractOpcode,
+                                             PadVectorOpcode, MRI))
+        return false;
+      continue;
+    }
+
+    if (IsValidExtractOpcode(UseOpcode)) {
+      // Direct extract is valid (plain, zext, or sext)
+      continue;
+    }
+
+    if (UseOpcode == PadVectorOpcode) {
+      // Pad is valid if only used by extracts
+      const Register PadDst = Use.getOperand(0).getReg();
+      for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
+        if (!IsValidExtractOpcode(PadUse.getOpcode()))
+          return false;
+      }
+      continue;
+    }
+
+    // Invalid use
+    return false;
+  }
+
+  return true;
+}
+
+/// Match unaligned vector loads that are only used for extracting elements
+/// and convert them to direct scalar loads.
+/// Supports s8, s16 and s32 element extractions from various vector
+/// configurations. Pattern:
+///   %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < N*M/8)
+///   %bitcast:_(<K x sX>) = G_BITCAST %vec
+///   %idx:_(s32) = G_CONSTANT i32 C
+///   %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
+/// Or with G_AIE_PAD_VECTOR_UNDEF:
+///   %vec = G_LOAD %ptr :: (unaligned)
+///   %bitcast = G_BITCAST %vec
+///   %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
+///   %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
+/// Converts to:
+///   %offset:_(s20) = G_CONSTANT i20 (C * sizeof(sX))
+///   %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
+///   %elt:_(sX) = G_LOAD %new_ptr :: (align = gcd(original align, offset))
+///   %result:_(s32) = G_[Z/S]EXT %elt
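+/// For instance (illustrative): extracting s8 element 5 of an align-1
+/// <16 x s8> load becomes a G_PTR_ADD of +5 followed by an align-1 s8 G_LOAD.
+/// Each matched extract gets its own scalar load, so the combine only fires
+/// when every use of the loaded vector is such an extract.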
+bool llvm::matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                     GISelChangeObserver &Observer,
+                                     BuildFnTy &MatchInfo) {
+  const MachineFunction &MF = *MI.getMF();
+  const AIEBaseInstrInfo &TII =
+      *static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  const unsigned Opcode = MI.getOpcode();
+  const unsigned ZExtExtractOpcode =
+      TII.getGenericExtractVectorEltOpcode(false);
+  const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
+  const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
+
+  const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
+  const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
+  const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+
+  if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
+    return false;
+
+  // Get the index operand
+  const Register IdxReg = MI.getOperand(2).getReg();
+  auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
+  if (!IdxCst)
+    return false;
+  const int64_t Index = IdxCst->Value.getSExtValue();
+
+  // Get the vector operand
+  const Register VecReg = MI.getOperand(1).getReg();
+  const LLT VecTy = MRI.getType(VecReg);
+
+  // Check if vector has extractable element types (s8, s16, or s32)
+  if (!VecTy.isVector())
+    return false;
+
+  const LLT ElemTy = VecTy.getElementType();
+  const unsigned ElemSize = ElemTy.getSizeInBits();
+  if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
+    return false;
+
+  // Trace back through G_AIE_PAD_VECTOR_UNDEF if present
+  MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
+  Register SourceVecReg = VecReg;
+
+  if (VecDefMI->getOpcode() == PadVectorOpcode) {
+    SourceVecReg = VecDefMI->getOperand(1).getReg();
+    VecDefMI = MRI.getVRegDef(SourceVecReg);
+  }
+
+  // Check for G_BITCAST (or direct vector if no bitcast needed)
+  Register LoadVecReg = SourceVecReg;
+  if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
+    LoadVecReg = VecDefMI->getOperand(1).getReg();
+
+  MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);
+
+  // Check if it's a load
+  if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
+    return false;
+
+  // Check if the load is unaligned relative to the vector's total size
+  if (LoadMI->memoperands_empty())
+    return false;
+
+  const MachineMemOperand *MMO = LoadMI->memoperands().front();
+  const LLT LoadVecTy = MRI.getType(LoadVecReg);
+  const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
+  // Vector is unaligned if alignment < vector size.
+  // This allows extracting elements when the vector load itself is unaligned.
+  if (MMO->getAlign().value() >= LoadVecSizeInBytes)
+    return false;
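+  // E.g. an align-1 G_LOAD of <4 x s32> (16 bytes) qualifies, while an
+  // align-16 load is already naturally aligned and is left alone.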
+
+  const unsigned ElemSizeInBytes = ElemSize / 8;
+
+  // Check that the loaded vector is only used by extracts (through bitcast and
+  // pad). The helper function will automatically traverse through bitcasts.
+  const Register LoadDstReg = LoadMI->getOperand(0).getReg();
+
+  if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
+                                         SExtExtractOpcode, PadVectorOpcode,
+                                         MRI))
+    return false;
+
+  // All checks passed, we can combine
+  MatchInfo = [=, &MI, &MRI, &Observer](MachineIRBuilder &B) {
+    const Register PtrReg = LoadMI->getOperand(1).getReg();
+    const LLT S20 = LLT::scalar(20);
+
+    // Calculate byte offset: Index * ElemSizeInBytes
+    const int64_t ByteOffset = Index * ElemSizeInBytes;
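+    // E.g. extracting s16 element 5 yields a byte offset of 5 * 2 = 10.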
+
+    // Set insertion point right after the original vector load
+    if (LoadMI->getNextNode())
+      B.setInstr(*LoadMI->getNextNode());
+    else
+      B.setInsertPt(*LoadMI->getParent(), LoadMI->getParent()->end());
+    B.setDebugLoc(LoadMI->getDebugLoc());
+
+    // Create offset constant and pointer add
+    const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
+    const Register NewPtrReg =
+        B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);
+
+    // Calculate alignment for the scalar load based on the original vector
+    // load alignment, using GCD to find the maximum provable alignment
+    const unsigned OrigAlign = MMO->getAlign().value();
+    const unsigned ScalarAlign =
+        ByteOffset == 0 ? OrigAlign : std::gcd(OrigAlign, (unsigned)ByteOffset);
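+    // E.g. with OrigAlign = 4 and ByteOffset = 6, gcd(4, 6) = 2, so the new
+    // address is provably 2-byte aligned regardless of the runtime base.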
+
+    // Create new scalar load with derived alignment
+    MachineFunction &MF = B.getMF();
+    MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
+                                ElemSizeInBytes, Align(ScalarAlign));
+
+    const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
+    B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO);
+
+    // Now set insertion point at the extract position for the copy/extension
+    B.setInstr(MI);
+
+    // Handle the result based on the original opcode
+    if (IsZExtExtract || IsSExtExtract) {
+      // Need to extend to s32
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (IsZExtExtract)
+        B.buildZExt(DstReg, LoadResultReg);
+      else
+        B.buildSExt(DstReg, LoadResultReg);
+    } else {
+      // G_EXTRACT_VECTOR_ELT - check if there's a G_ZEXT or G_SEXT user
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (MRI.hasOneNonDBGUse(DstReg)) {
+        MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
+        const unsigned UserOpcode = UserMI->getOpcode();
+        if (UserOpcode == TargetOpcode::G_ZEXT ||
+            UserOpcode == TargetOpcode::G_SEXT) {
+          // Combine the extract and ext
+          const Register ExtDstReg = UserMI->getOperand(0).getReg();
+          if (UserOpcode == TargetOpcode::G_ZEXT)
+            B.buildZExt(ExtDstReg, LoadResultReg);
+          else
+            B.buildSExt(ExtDstReg, LoadResultReg);
+          Observer.erasingInstr(*UserMI);
+          UserMI->eraseFromParent();
+          Observer.erasingInstr(MI);
+          MI.eraseFromParent();
+          return;
+        }
+      }
+      // Just copy the result
+      B.buildCopy(DstReg, LoadResultReg);
+    }
+
+    Observer.erasingInstr(MI);
+    MI.eraseFromParent();
+  };
+
+  return true;
+}