@@ -4291,3 +4291,230 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
 
   return true;
 }
+
+/// Helper function to recursively check if all uses of a register are valid
+/// for the unaligned extract load combiner.
+/// Automatically traverses through bitcasts to validate all usage patterns.
+/// Valid terminal uses are: direct extracts or pad vector operations.
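+/// For example, a chain such as
+///   %vec:_(<8 x s8>) = G_LOAD %ptr
+///   %cast:_(<2 x s32>) = G_BITCAST %vec
+///   %elt:_(s32) = G_EXTRACT_VECTOR_ELT %cast, %idx
+/// is accepted, while any other user of the load (e.g. a G_STORE) rejects
+/// the combine.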
+static bool areLoadUsesValidForExtractCombine(Register Reg,
+                                              unsigned ZExtExtractOpcode,
+                                              unsigned SExtExtractOpcode,
+                                              unsigned PadVectorOpcode,
+                                              MachineRegisterInfo &MRI) {
+
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    const unsigned UseOpcode = Use.getOpcode();
+
+    if (UseOpcode == TargetOpcode::G_BITCAST) {
+      // Recursively check bitcast uses.
+      const Register BitcastDst = Use.getOperand(0).getReg();
+      if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
+                                             SExtExtractOpcode,
+                                             PadVectorOpcode, MRI))
+        return false;
+      continue;
+    }
+
+    if (UseOpcode == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+      // Direct extract is valid.
+      continue;
+    }
+
+    if (UseOpcode == PadVectorOpcode) {
+      // Pad is valid if only used by extracts.
+      const Register PadDst = Use.getOperand(0).getReg();
+      for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
+        const unsigned PadUseOpcode = PadUse.getOpcode();
+        if (PadUseOpcode != ZExtExtractOpcode &&
+            PadUseOpcode != SExtExtractOpcode &&
+            PadUseOpcode != TargetOpcode::G_EXTRACT_VECTOR_ELT)
+          return false;
+      }
+      continue;
+    }
+
+    // Invalid use.
+    return false;
+  }
+
+  return true;
+}
+
+/// Match unaligned vector loads that are only used for extracting elements
+/// and convert them to direct scalar loads.
+/// Supports s8, s16 and s32 element extractions from various vector
+/// configurations. Pattern:
+///   %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < N*M/8)
+///   %bitcast:_(<K x sX>) = G_BITCAST %vec
+///   %idx:_(s32) = G_CONSTANT i32 C
+///   %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
+/// Or with G_AIE_PAD_VECTOR_UNDEF:
+///   %vec = G_LOAD %ptr :: (unaligned)
+///   %bitcast = G_BITCAST %vec
+///   %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
+///   %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
+/// Converts to:
+///   %offset:_(s20) = G_CONSTANT i20 (C * sizeof(sX))
+///   %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
+///   %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
+///   %result:_(s32) = G_[Z/S]EXT %elt
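+/// For example, extracting element 3 from an under-aligned <8 x s16> load
+/// becomes an s16 load at byte offset 3 * 2 = 6 from the original pointer.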
+bool llvm::matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                     GISelChangeObserver &Observer,
+                                     BuildFnTy &MatchInfo) {
+  const MachineFunction &MF = *MI.getMF();
+  const AIEBaseInstrInfo &TII =
+      *static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  const unsigned Opcode = MI.getOpcode();
+  const unsigned ZExtExtractOpcode =
+      TII.getGenericExtractVectorEltOpcode(false);
+  const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
+  const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
+
+  const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
+  const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
+  const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+
+  if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
+    return false;
+
+  // Get the index operand; only a constant index can be folded into a
+  // pointer offset.
+  const Register IdxReg = MI.getOperand(2).getReg();
+  auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
+  if (!IdxCst)
+    return false;
+  const int64_t Index = IdxCst->Value.getSExtValue();
+
+  // Get the vector operand.
+  const Register VecReg = MI.getOperand(1).getReg();
+  const LLT VecTy = MRI.getType(VecReg);
+
+  // Check if the vector has an extractable element type (s8, s16, or s32).
+  if (!VecTy.isVector())
+    return false;
+
+  const LLT ElemTy = VecTy.getElementType();
+  const unsigned ElemSize = ElemTy.getSizeInBits();
+  if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
+    return false;
+
+  // Trace back through G_AIE_PAD_VECTOR_UNDEF if present.
+  MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
+  Register SourceVecReg = VecReg;
+
+  if (VecDefMI->getOpcode() == PadVectorOpcode) {
+    SourceVecReg = VecDefMI->getOperand(1).getReg();
+    VecDefMI = MRI.getVRegDef(SourceVecReg);
+  }
+
+  // Look through G_BITCAST (or use the vector directly if there is none).
+  Register LoadVecReg = SourceVecReg;
+  if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
+    LoadVecReg = VecDefMI->getOperand(1).getReg();
+
+  MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);
+
+  // Check that it is a load.
+  if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
+    return false;
+
+  // Check if the load is unaligned relative to the vector's total size.
+  if (LoadMI->memoperands_empty())
+    return false;
+
+  const MachineMemOperand *MMO = LoadMI->memoperands().front();
+  const LLT LoadVecTy = MRI.getType(LoadVecReg);
+  const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
+  // The vector is unaligned if its alignment is smaller than the vector
+  // size; in that case we extract the elements via scalar loads instead.
+  if (MMO->getAlign().value() >= LoadVecSizeInBytes)
+    return false;
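+  // Illustrative: a <4 x s32> load (16 bytes) with align 4 qualifies, while
+  // the same load with align 16 or more is left unchanged.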
+
+  const unsigned ElemSizeInBytes = ElemSize / 8;
+
+  // Check that the loaded vector is only used by extracts (through bitcast
+  // and pad). The helper function automatically traverses through bitcasts.
+  const Register LoadDstReg = LoadMI->getOperand(0).getReg();
+
+  if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
+                                         SExtExtractOpcode, PadVectorOpcode,
+                                         MRI))
+    return false;
+
+  // All checks passed, we can combine.
+  MatchInfo = [=, &MI, &MRI, &Observer](MachineIRBuilder &B) {
+    const Register PtrReg = LoadMI->getOperand(1).getReg();
+    const LLT S20 = LLT::scalar(20);
+
+    // Calculate the byte offset: Index * ElemSizeInBytes.
+    const int64_t ByteOffset = Index * ElemSizeInBytes;
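+    // E.g. extracting element 3 of s16 elements yields ByteOffset = 3 * 2 = 6.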
+
+    // Set the insertion point right after the original vector load.
+    if (LoadMI->getNextNode())
+      B.setInstr(*LoadMI->getNextNode());
+    else
+      B.setInsertPt(*LoadMI->getParent(), LoadMI->getParent()->end());
+    B.setDebugLoc(LoadMI->getDebugLoc());
+
+    // Create the offset constant and the pointer add.
+    const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
+    const Register NewPtrReg =
+        B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);
+
+    // Derive the scalar load's alignment from the original vector load's:
+    // if the byte offset is a multiple of the original alignment, keep it;
+    // otherwise fall back to align 1.
+    const unsigned OrigAlign = MMO->getAlign().value();
+    const unsigned ScalarAlign = (ByteOffset % OrigAlign == 0) ? OrigAlign : 1;
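+    // Illustrative: with OrigAlign = 4, offsets 0, 4, 8 keep align 4, while
+    // offsets 1, 2, 3 fall back to align 1.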
+
+    // Create the new scalar load with the derived alignment.
+    MachineFunction &MF = B.getMF();
+    MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
+                                ElemSizeInBytes, Align(ScalarAlign));
+
+    const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
+    B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO);
+
+    // Now set the insertion point at the extract for the copy/extension.
+    B.setInstr(MI);
+
+    // Handle the result based on the original opcode.
+    if (IsZExtExtract || IsSExtExtract) {
+      // Need to extend to s32.
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (IsZExtExtract)
+        B.buildZExt(DstReg, LoadResultReg);
+      else
+        B.buildSExt(DstReg, LoadResultReg);
+    } else {
+      // G_EXTRACT_VECTOR_ELT: check if there is a G_ZEXT or G_SEXT user.
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (MRI.hasOneNonDBGUse(DstReg)) {
+        MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
+        const unsigned UserOpcode = UserMI->getOpcode();
+        if (UserOpcode == TargetOpcode::G_ZEXT ||
+            UserOpcode == TargetOpcode::G_SEXT) {
+          // Combine the extract and the extension.
+          const Register ExtDstReg = UserMI->getOperand(0).getReg();
+          if (UserOpcode == TargetOpcode::G_ZEXT)
+            B.buildZExt(ExtDstReg, LoadResultReg);
+          else
+            B.buildSExt(ExtDstReg, LoadResultReg);
+          Observer.erasingInstr(*UserMI);
+          UserMI->eraseFromParent();
+          Observer.erasingInstr(MI);
+          MI.eraseFromParent();
+          return;
+        }
+      }
+      // Just copy the result.
+      B.buildCopy(DstReg, LoadResultReg);
+    }
+
+    Observer.erasingInstr(MI);
+    MI.eraseFromParent();
+  };
+
+  return true;
+}