Commit 4838126

[AIEX] Add a combiner to handle extract from unaligned vector load
In this case, we can load the scalar value directly instead of building a full vector just to extract one element; the legalizer would scalarize such an unaligned vector load anyway.
1 parent 7f68bce commit 4838126
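
In generic MIR, the rewrite looks roughly like the sketch below (a minimal example following the doc comment added in AIECombinerHelper.cpp; the vector type, index, and register names are illustrative):

    ; Before: unaligned vector load whose only use is an element extract
    %vec:_(<8 x s32>) = G_LOAD %ptr(p0) :: (load (<8 x s32>), align 1)
    %idx:_(s32) = G_CONSTANT i32 3
    %elt:_(s32) = G_EXTRACT_VECTOR_ELT %vec(<8 x s32>), %idx(s32)

    ; After: scalar load at the element's byte offset (3 * 4 bytes = 12)
    %off:_(s20) = G_CONSTANT i20 12
    %addr:_(p0) = G_PTR_ADD %ptr, %off(s20)
    %elt:_(s32) = G_LOAD %addr(p0) :: (load (s32), align 1)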

File tree

4 files changed, +1068 -2 lines changed

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 2 deletions
@@ -252,6 +252,12 @@ def combine_trunc_load : GICombineRule<
          [{ return matchNarrowTruncLoad(*${root}, MRI, Helper, Observer, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
 
+def combine_unaligned_extract_load : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_EXTRACT_VECTOR_ELT, G_AIE_ZEXT_EXTRACT_VECTOR_ELT, G_AIE_SEXT_EXTRACT_VECTOR_ELT): $root,
+         [{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
+
 // AIE-specifc combines (currently shared by AIE2 and AIE2P).
 def aie_additional_combines : GICombineGroup<[
   combine_unpad_vector,
@@ -274,7 +280,8 @@ def aie_additional_combines : GICombineGroup<[
   combine_align_memset,
   combine_peel_memset,
   combine_pack_stores_into_memset,
-  combine_trunc_load
+  combine_trunc_load,
+  combine_unaligned_extract_load
 ]>;
 
 // AIE2P-specific combines.
@@ -408,4 +415,3 @@ def AIE2PPostLegalizerCustomCombiner
   combine_add_vector_elt_undef,
 ]> {
 }
-

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 230 additions & 0 deletions
@@ -4291,3 +4291,233 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
 
   return true;
 }
+
+/// Helper function to recursively check if all uses of a register are valid
+/// for the unaligned extract load combiner.
+/// Automatically traverses through bitcasts to validate all usage patterns.
+/// Valid terminal uses are: direct extracts or pad vector operations (with use
+/// check).
+static bool areLoadUsesValidForExtractCombine(Register Reg,
+                                              unsigned ZExtExtractOpcode,
+                                              unsigned SExtExtractOpcode,
+                                              unsigned PadVectorOpcode,
+                                              MachineRegisterInfo &MRI) {
+
+  auto IsValidExtractOpcode = [&](unsigned Opcode) {
+    return Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+           Opcode == ZExtExtractOpcode || Opcode == SExtExtractOpcode;
+  };
+
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    const unsigned UseOpcode = Use.getOpcode();
+
+    if (UseOpcode == TargetOpcode::G_BITCAST) {
+      // Recursively check bitcast uses
+      const Register BitcastDst = Use.getOperand(0).getReg();
+      if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
+                                             SExtExtractOpcode, PadVectorOpcode,
+                                             MRI))
+        return false;
+      continue;
+    }
+
+    if (IsValidExtractOpcode(UseOpcode)) {
+      // Direct extract is valid (plain, zext, or sext)
+      continue;
+    }
+
+    if (UseOpcode == PadVectorOpcode) {
+      // Pad is valid if only used by extracts
+      const Register PadDst = Use.getOperand(0).getReg();
+      for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
+        if (!IsValidExtractOpcode(PadUse.getOpcode()))
+          return false;
+      }
+      continue;
+    }
+
+    // Invalid use
+    return false;
+  }
+
+  return true;
+}
+
+/// Match unaligned vector loads that are only used for extracting elements
+/// and convert them to direct scalar loads.
+/// Supports s8, s16 and s32 element extractions from various vector
+/// configurations. Pattern:
+///   %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < M/8)
+///   %bitcast:_(<K x sX>) = G_BITCAST %vec
+///   %idx:_(s32) = G_CONSTANT i32 N
+///   %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
+/// Or with G_AIE_PAD_VECTOR_UNDEF:
+///   %vec = G_LOAD %ptr :: (unaligned)
+///   %bitcast = G_BITCAST %vec
+///   %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
+///   %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
+/// Converts to:
+///   %offset:_(s20) = G_CONSTANT i20 (N * sizeof(sX))
+///   %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
+///   %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
+///   %result:_(s32) = G_[Z/S]EXT %elt
+bool llvm::matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                     GISelChangeObserver &Observer,
+                                     BuildFnTy &MatchInfo) {
+  const MachineFunction &MF = *MI.getMF();
+  const AIEBaseInstrInfo &TII =
+      *static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  const unsigned Opcode = MI.getOpcode();
+  const unsigned ZExtExtractOpcode =
+      TII.getGenericExtractVectorEltOpcode(false);
+  const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
+  const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
+
+  const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
+  const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
+  const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+
+  if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
+    return false;
+
+  // Get the index operand
+  const Register IdxReg = MI.getOperand(2).getReg();
+  auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
+  if (!IdxCst)
+    return false;
+  const int64_t Index = IdxCst->Value.getSExtValue();
+
+  // Get the vector operand
+  const Register VecReg = MI.getOperand(1).getReg();
+  const LLT VecTy = MRI.getType(VecReg);
+
+  // Check if vector has extractable element types (s8, s16, or s32)
+  if (!VecTy.isVector())
+    return false;
+
+  const LLT ElemTy = VecTy.getElementType();
+  const unsigned ElemSize = ElemTy.getSizeInBits();
+  if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
+    return false;
+
+  // Trace back through G_AIE_PAD_VECTOR_UNDEF if present
+  MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
+  Register SourceVecReg = VecReg;
+
+  if (VecDefMI->getOpcode() == PadVectorOpcode) {
+    SourceVecReg = VecDefMI->getOperand(1).getReg();
+    VecDefMI = MRI.getVRegDef(SourceVecReg);
+  }
+
+  // Check for G_BITCAST (or direct vector if no bitcast needed)
+  Register LoadVecReg = SourceVecReg;
+  if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
+    LoadVecReg = VecDefMI->getOperand(1).getReg();
+
+  MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);
+
+  // Check if it's a load
+  if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
+    return false;
+
+  // Check if the load is unaligned relative to the vector's total size
+  if (LoadMI->memoperands_empty())
+    return false;
+
+  const MachineMemOperand *MMO = LoadMI->memoperands().front();
+  const LLT LoadVecTy = MRI.getType(LoadVecReg);
+  const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
+  // Vector is unaligned if alignment < vector size
+  // This allows extracting elements when the vector load itself is unaligned
+  if (MMO->getAlign().value() >= LoadVecSizeInBytes)
+    return false;
+
+  const unsigned ElemSizeInBytes = ElemSize / 8;
+
+  // Check that the loaded vector is only used by extracts (through bitcast and
+  // pad). The helper function will automatically traverse through bitcasts.
+  const Register LoadDstReg = LoadMI->getOperand(0).getReg();
+
+  if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
+                                         SExtExtractOpcode, PadVectorOpcode,
+                                         MRI))
+    return false;
+
+  // All checks passed, we can combine
+  MatchInfo = [=, &MI, &MRI, &Observer](MachineIRBuilder &B) {
+    const Register PtrReg = LoadMI->getOperand(1).getReg();
+    const LLT S20 = LLT::scalar(20);
+
+    // Calculate byte offset: Index * ElemSizeInBytes
+    const int64_t ByteOffset = Index * ElemSizeInBytes;
+
+    // Set insertion point right after the original vector load
+    if (LoadMI->getNextNode())
+      B.setInstr(*LoadMI->getNextNode());
+    else
+      B.setInsertPt(*LoadMI->getParent(), LoadMI->getParent()->end());
+    B.setDebugLoc(LoadMI->getDebugLoc());
+
+    // Create offset constant and pointer add
+    const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
+    const Register NewPtrReg =
+        B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);
+
+    // Calculate alignment for scalar load based on original vector load
+    // alignment using GCD to find the maximum provable alignment
+    const unsigned OrigAlign = MMO->getAlign().value();
+    const unsigned ScalarAlign =
+        ByteOffset == 0 ? OrigAlign : std::gcd(OrigAlign, (unsigned)ByteOffset);
+
+    // Create new scalar load with derived alignment
+    MachineFunction &MF = B.getMF();
+    MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
+                                ElemSizeInBytes, Align(ScalarAlign));
+
+    const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
+    B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO);
+
+    // Now set insertion point at the extract position for the copy/extension
+    B.setInstr(MI);
+
+    // Handle the result based on the original opcode
+    if (IsZExtExtract || IsSExtExtract) {
+      // Need to extend to s32
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (IsZExtExtract)
+        B.buildZExt(DstReg, LoadResultReg);
+      else
+        B.buildSExt(DstReg, LoadResultReg);
+    } else {
+      // G_EXTRACT_VECTOR_ELT - check if there's a G_ZEXT or G_SEXT user
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (MRI.hasOneNonDBGUse(DstReg)) {
+        MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
+        const unsigned UserOpcode = UserMI->getOpcode();
+        if (UserOpcode == TargetOpcode::G_ZEXT ||
+            UserOpcode == TargetOpcode::G_SEXT) {
+          // Combine the extract and ext
+          const Register ExtDstReg = UserMI->getOperand(0).getReg();
+          if (UserOpcode == TargetOpcode::G_ZEXT)
+            B.buildZExt(ExtDstReg, LoadResultReg);
+          else
+            B.buildSExt(ExtDstReg, LoadResultReg);
+          Observer.erasingInstr(*UserMI);
+          UserMI->eraseFromParent();
+          Observer.erasingInstr(MI);
+          MI.eraseFromParent();
+          return;
+        }
+      }
+      // Just copy the result
+      B.buildCopy(DstReg, LoadResultReg);
+    }
+
+    Observer.erasingInstr(MI);
+    MI.eraseFromParent();
+  };
+
+  return true;
+}
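
The alignment on the new scalar load is derived with gcd from the original load's alignment and the element's byte offset. A worked example (a sketch with illustrative values, not taken from the patch's tests): for an <8 x s16> load with align 4, extracting element 3 gives a byte offset of 3 * 2 = 6, and gcd(4, 6) = 2, so the replacement load is tagged align 2:

    %vec:_(<8 x s16>) = G_LOAD %ptr(p0) :: (load (<8 x s16>), align 4)  ; 16-byte vector, align 4 < 16, so the combine fires
    %idx:_(s32) = G_CONSTANT i32 3
    %elt:_(s16) = G_EXTRACT_VECTOR_ELT %vec(<8 x s16>), %idx(s32)

    ; becomes
    %off:_(s20) = G_CONSTANT i20 6
    %addr:_(p0) = G_PTR_ADD %ptr, %off(s20)
    %elt:_(s16) = G_LOAD %addr(p0) :: (load (s16), align 2)             ; gcd(4, 6) = 2

When the byte offset is 0, the original alignment is simply reused.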

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 4 additions & 0 deletions
@@ -308,6 +308,10 @@ bool matchExtractVecEltAssertBcst(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   const AIEBaseInstrInfo &TII,
                                   GISelChangeObserver &Observer,
                                   BuildFnTy &MatchInfo);
+
+bool matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                               GISelChangeObserver &Observer,
+                               BuildFnTy &MatchInfo);
 } // namespace llvm
 
 #endif
