Commit d3c310b

[AIEX] Add a combiner to handle extract from unaligned vector load
When a vector load is unaligned and only feeds element extracts, we can load the scalar value directly instead of building a full vector just to extract from it (the legalizer would scalarize such a load anyway).
1 parent 7f68bce commit d3c310b
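
To illustrate the plain-extract case, here is a minimal generic-MIR sketch of the transformation; the <8 x s16> type, the align-2 memory operand, the index, and the register names are assumptions for the example rather than taken from the commit's tests, and the intermediate COPY the combiner emits is folded away for brevity:

Before:
  %vec:_(<8 x s16>) = G_LOAD %ptr(p0) :: (load (<8 x s16>), align 2)
  %idx:_(s32) = G_CONSTANT i32 3
  %elt:_(s16) = G_EXTRACT_VECTOR_ELT %vec(<8 x s16>), %idx(s32)

After (the vector load becomes dead once every extract has been rewritten):
  %off:_(s20) = G_CONSTANT i20 6
  %newptr:_(p0) = G_PTR_ADD %ptr, %off(s20)
  %elt:_(s16) = G_LOAD %newptr(p0) :: (load (s16), align 2)

The i20 offset 6 is index 3 times the 2-byte element size; because 6 is a multiple of the original alignment of 2, the scalar load keeps align 2.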

File tree

4 files changed, +873 -2 lines changed


llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 2 deletions
@@ -252,6 +252,12 @@ def combine_trunc_load : GICombineRule<
          [{ return matchNarrowTruncLoad(*${root}, MRI, Helper, Observer, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
 
+def combine_unaligned_extract_load : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_EXTRACT_VECTOR_ELT, G_AIE_ZEXT_EXTRACT_VECTOR_ELT, G_AIE_SEXT_EXTRACT_VECTOR_ELT): $root,
+         [{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
+
 // AIE-specifc combines (currently shared by AIE2 and AIE2P).
 def aie_additional_combines : GICombineGroup<[
   combine_unpad_vector,
@@ -274,7 +280,8 @@ def aie_additional_combines : GICombineGroup<[
   combine_align_memset,
   combine_peel_memset,
   combine_pack_stores_into_memset,
-  combine_trunc_load
+  combine_trunc_load,
+  combine_unaligned_extract_load
 ]>;
 
 // AIE2P-specific combines.
@@ -408,4 +415,3 @@ def AIE2PPostLegalizerCustomCombiner
   combine_add_vector_elt_undef,
 ]> {
 }
-

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 227 additions & 0 deletions
@@ -4291,3 +4291,230 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
 
   return true;
 }
+
+/// Helper function to recursively check if all uses of a register are valid
+/// for the unaligned extract load combiner.
+/// Automatically traverses through bitcasts to validate all usage patterns.
+/// Valid terminal uses are: direct extracts or pad vector operations.
+static bool areLoadUsesValidForExtractCombine(Register Reg,
+                                              unsigned ZExtExtractOpcode,
+                                              unsigned SExtExtractOpcode,
+                                              unsigned PadVectorOpcode,
+                                              MachineRegisterInfo &MRI) {
+
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    const unsigned UseOpcode = Use.getOpcode();
+
+    if (UseOpcode == TargetOpcode::G_BITCAST) {
+      // Recursively check bitcast uses
+      const Register BitcastDst = Use.getOperand(0).getReg();
+      if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
+                                             SExtExtractOpcode, PadVectorOpcode,
+                                             MRI))
+        return false;
+      continue;
+    }
+
+    if (UseOpcode == TargetOpcode::G_EXTRACT_VECTOR_ELT) {
+      // Direct extract is valid
+      continue;
+    }
+
+    if (UseOpcode == PadVectorOpcode) {
+      // Pad is valid if only used by extracts
+      const Register PadDst = Use.getOperand(0).getReg();
+      for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
+        const unsigned PadUseOpcode = PadUse.getOpcode();
+        if (PadUseOpcode != ZExtExtractOpcode &&
+            PadUseOpcode != SExtExtractOpcode &&
+            PadUseOpcode != TargetOpcode::G_EXTRACT_VECTOR_ELT)
+          return false;
+      }
+      continue;
+    }
+
+    // Invalid use
+    return false;
+  }
+
+  return true;
+}
+
+/// Match unaligned vector loads that are only used for extracting elements
+/// and convert them to direct scalar loads.
+/// Supports s8, s16 and s32 element extractions from various vector
+/// configurations. Pattern:
+///   %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < M/8)
+///   %bitcast:_(<K x sX>) = G_BITCAST %vec
+///   %idx:_(s32) = G_CONSTANT i32 N
+///   %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
+/// Or with G_AIE_PAD_VECTOR_UNDEF:
+///   %vec = G_LOAD %ptr :: (unaligned)
+///   %bitcast = G_BITCAST %vec
+///   %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
+///   %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
+/// Converts to:
+///   %offset:_(s20) = G_CONSTANT i20 (N * sizeof(sX))
+///   %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
+///   %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
+///   %result:_(s32) = G_[Z/S]EXT %elt
+bool llvm::matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                                     GISelChangeObserver &Observer,
+                                     BuildFnTy &MatchInfo) {
+  const MachineFunction &MF = *MI.getMF();
+  const AIEBaseInstrInfo &TII =
+      *static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  const unsigned Opcode = MI.getOpcode();
+  const unsigned ZExtExtractOpcode =
+      TII.getGenericExtractVectorEltOpcode(false);
+  const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
+  const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
+
+  const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
+  const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
+  const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+
+  if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
+    return false;
+
+  // Get the index operand
+  const Register IdxReg = MI.getOperand(2).getReg();
+  auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
+  if (!IdxCst)
+    return false;
+  const int64_t Index = IdxCst->Value.getSExtValue();
+
+  // Get the vector operand
+  const Register VecReg = MI.getOperand(1).getReg();
+  const LLT VecTy = MRI.getType(VecReg);
+
+  // Check if vector has extractable element types (s8, s16, or s32)
+  if (!VecTy.isVector())
+    return false;
+
+  const LLT ElemTy = VecTy.getElementType();
+  const unsigned ElemSize = ElemTy.getSizeInBits();
+  if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
+    return false;
+
+  // Trace back through G_AIE_PAD_VECTOR_UNDEF if present
+  MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
+  Register SourceVecReg = VecReg;
+
+  if (VecDefMI->getOpcode() == PadVectorOpcode) {
+    SourceVecReg = VecDefMI->getOperand(1).getReg();
+    VecDefMI = MRI.getVRegDef(SourceVecReg);
+  }
+
+  // Check for G_BITCAST (or direct vector if no bitcast needed)
+  Register LoadVecReg = SourceVecReg;
+  if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
+    LoadVecReg = VecDefMI->getOperand(1).getReg();
+
+  MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);
+
+  // Check if it's a load
+  if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
+    return false;
+
+  // Check if the load is unaligned relative to the vector's total size
+  if (LoadMI->memoperands_empty())
+    return false;
+
+  const MachineMemOperand *MMO = LoadMI->memoperands().front();
+  const LLT LoadVecTy = MRI.getType(LoadVecReg);
+  const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
+  // Vector is unaligned if alignment < vector size
+  // This allows extracting elements when the vector load itself is unaligned
+  if (MMO->getAlign().value() >= LoadVecSizeInBytes)
+    return false;
+
+  const unsigned ElemSizeInBytes = ElemSize / 8;
+
+  // Check that the loaded vector is only used by extracts (through bitcast and
+  // pad). The helper function will automatically traverse through bitcasts.
+  const Register LoadDstReg = LoadMI->getOperand(0).getReg();
+
+  if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
+                                         SExtExtractOpcode, PadVectorOpcode,
+                                         MRI))
+    return false;
+
+  // All checks passed, we can combine
+  MatchInfo = [=, &MI, &MRI, &Observer](MachineIRBuilder &B) {
+    const Register PtrReg = LoadMI->getOperand(1).getReg();
+    const LLT S20 = LLT::scalar(20);
+
+    // Calculate byte offset: Index * ElemSizeInBytes
+    const int64_t ByteOffset = Index * ElemSizeInBytes;
+
+    // Set insertion point right after the original vector load
+    if (LoadMI->getNextNode())
+      B.setInstr(*LoadMI->getNextNode());
+    else
+      B.setInsertPt(*LoadMI->getParent(), LoadMI->getParent()->end());
+    B.setDebugLoc(LoadMI->getDebugLoc());
+
+    // Create offset constant and pointer add
+    const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
+    const Register NewPtrReg =
+        B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);
+
+    // Calculate alignment for the scalar load based on the original vector
+    // load alignment. If the byte offset is a multiple of the original
+    // alignment, keep it; otherwise fall back to align 1.
+    const unsigned OrigAlign = MMO->getAlign().value();
+    const unsigned ScalarAlign = (ByteOffset % OrigAlign == 0) ? OrigAlign : 1;
+
+    // Create new scalar load with derived alignment
+    MachineFunction &MF = B.getMF();
+    MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
+                                ElemSizeInBytes, Align(ScalarAlign));
+
+    const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
+    B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO);
+
+    // Now set insertion point at the extract position for the copy/extension
+    B.setInstr(MI);
+
+    // Handle the result based on the original opcode
+    if (IsZExtExtract || IsSExtExtract) {
+      // Need to extend to s32
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (IsZExtExtract)
+        B.buildZExt(DstReg, LoadResultReg);
+      else
+        B.buildSExt(DstReg, LoadResultReg);
+    } else {
+      // G_EXTRACT_VECTOR_ELT - check if there's a G_ZEXT or G_SEXT user
+      const Register DstReg = MI.getOperand(0).getReg();
+      if (MRI.hasOneNonDBGUse(DstReg)) {
+        MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
+        const unsigned UserOpcode = UserMI->getOpcode();
+        if (UserOpcode == TargetOpcode::G_ZEXT ||
+            UserOpcode == TargetOpcode::G_SEXT) {
+          // Combine the extract and ext
+          const Register ExtDstReg = UserMI->getOperand(0).getReg();
+          if (UserOpcode == TargetOpcode::G_ZEXT)
+            B.buildZExt(ExtDstReg, LoadResultReg);
+          else
+            B.buildSExt(ExtDstReg, LoadResultReg);
+          Observer.erasingInstr(*UserMI);
+          UserMI->eraseFromParent();
+          Observer.erasingInstr(MI);
+          MI.eraseFromParent();
+          return;
+        }
+      }
+      // Just copy the result
+      B.buildCopy(DstReg, LoadResultReg);
+    }
+
+    Observer.erasingInstr(MI);
+    MI.eraseFromParent();
+  };
+
+  return true;
+}
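
For the extending-extract path, here is a sketch of the G_AIE_PAD_VECTOR_UNDEF pattern described in the doc comment above; the <16 x s16> source type, the <32 x s16> padded type, the align-2 memory operand, and index 3 are illustrative assumptions, not taken from the commit's tests:

Before:
  %vec:_(<16 x s16>) = G_LOAD %ptr(p0) :: (load (<16 x s16>), align 2)
  %padded:_(<32 x s16>) = G_AIE_PAD_VECTOR_UNDEF %vec(<16 x s16>)
  %idx:_(s32) = G_CONSTANT i32 3
  %res:_(s32) = G_AIE_ZEXT_EXTRACT_VECTOR_ELT %padded(<32 x s16>), %idx(s32)

After:
  %off:_(s20) = G_CONSTANT i20 6
  %newptr:_(p0) = G_PTR_ADD %ptr, %off(s20)
  %elt:_(s16) = G_LOAD %newptr(p0) :: (load (s16), align 2)
  %res:_(s32) = G_ZEXT %elt(s16)

The zero-extending extract collapses into a scalar G_LOAD followed by G_ZEXT, matching the IsZExtExtract branch of the build function; the offset and alignment follow the same ByteOffset and ScalarAlign computations shown above.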

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 4 additions & 0 deletions
@@ -308,6 +308,10 @@ bool matchExtractVecEltAssertBcst(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   const AIEBaseInstrInfo &TII,
                                   GISelChangeObserver &Observer,
                                   BuildFnTy &MatchInfo);
+
+bool matchUnalignedExtractLoad(MachineInstr &MI, MachineRegisterInfo &MRI,
+                               GISelChangeObserver &Observer,
+                               BuildFnTy &MatchInfo);
 } // namespace llvm
 
 #endif
