
Commit e6e7fd7

[AIEX] Add a combiner to handle extract from unaligned vector load
In this case, we can load the scalar value directly instead of building a full vector just to extract from it (the legalizer will scalarize the unaligned load anyway).
1 parent 2733f2e commit e6e7fd7
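
As a source-level illustration (a hypothetical C++ analogy, not code from this patch or its tests), extracting one lane from an unaligned vector load needs nothing more than a scalar load at the lane's byte offset:

#include <cstdint>
#include <cstring>

// Before the combine: conceptually, a whole 16-byte vector is loaded from
// an unaligned pointer just to read lane 3.
int32_t extract_via_vector(const uint8_t *p) {
  int32_t vec[4];
  std::memcpy(vec, p, sizeof(vec)); // unaligned 16-byte load
  return vec[3];
}

// After the combine: load the 4-byte element directly at byte offset
// 3 * sizeof(int32_t), exactly what the new G_PTR_ADD + scalar G_LOAD do.
int32_t extract_direct(const uint8_t *p) {
  int32_t elt;
  std::memcpy(&elt, p + 3 * sizeof(int32_t), sizeof(elt)); // scalar load
  return elt;
}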

4 files changed: 1046 additions & 2 deletions

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 2 deletions
@@ -252,6 +252,12 @@ def combine_trunc_load : GICombineRule<
     [{ return matchNarrowTruncLoad(*${root}, MRI, Helper, Observer, ${matchinfo}); }]),
   (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
 
+def combine_unaligned_extract_load : GICombineRule<
+  (defs root:$root, build_fn_matchinfo:$matchinfo),
+  (match (wip_match_opcode G_EXTRACT_VECTOR_ELT, G_AIE_ZEXT_EXTRACT_VECTOR_ELT, G_AIE_SEXT_EXTRACT_VECTOR_ELT): $root,
+  [{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
+  (apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
+
 // AIE-specifc combines (currently shared by AIE2 and AIE2P).
 def aie_additional_combines : GICombineGroup<[
   combine_unpad_vector,
@@ -274,7 +280,8 @@ def aie_additional_combines : GICombineGroup<[
   combine_align_memset,
   combine_peel_memset,
   combine_pack_stores_into_memset,
-  combine_trunc_load
+  combine_trunc_load,
+  combine_unaligned_extract_load
 ]>;
 
 // AIE2P-specific combines.
@@ -408,4 +415,3 @@ def AIE2PPostLegalizerCustomCombiner
   combine_add_vector_elt_undef,
 ]> {
 }
-

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 207 additions & 0 deletions
@@ -4292,3 +4292,210 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,
 
   return true;
 }
+
+/// Helper function to recursively check if all uses of a register are valid
+/// for the unaligned extract load combiner.
+/// Automatically traverses through bitcasts to validate all usage patterns.
+/// Valid terminal uses are: direct extracts or pad vector operations (with use
+/// check).
+static bool areLoadUsesValidForExtractCombine(Register Reg,
+                                              unsigned ZExtExtractOpcode,
+                                              unsigned SExtExtractOpcode,
+                                              unsigned PadVectorOpcode,
+                                              MachineRegisterInfo &MRI) {
+
+  auto IsValidExtractOpcode = [&](unsigned Opcode) {
+    return Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
+           Opcode == ZExtExtractOpcode || Opcode == SExtExtractOpcode;
+  };
+
+  for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
+    const unsigned UseOpcode = Use.getOpcode();
+
+    if (UseOpcode == TargetOpcode::G_BITCAST) {
+      // Recursively check bitcast uses
+      const Register BitcastDst = Use.getOperand(0).getReg();
+      if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
+                                             SExtExtractOpcode, PadVectorOpcode,
+                                             MRI))
+        return false;
+      continue;
+    }
+
+    if (IsValidExtractOpcode(UseOpcode)) {
+      // Direct extract is valid (plain, zext, or sext)
+      continue;
+    }
+
+    if (UseOpcode == PadVectorOpcode) {
+      // Pad is valid if only used by extracts
+      const Register PadDst = Use.getOperand(0).getReg();
+      for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
+        if (!IsValidExtractOpcode(PadUse.getOpcode()))
+          return false;
+      }
+      continue;
+    }
+
+    // Invalid use
+    return false;
+  }
+
+  return true;
+}
+
+/// Match unaligned vector loads that are only used for extracting elements
+/// and convert them to direct scalar loads.
+/// Supports s8, s16 and s32 element extractions from various vector
+/// configurations. Pattern:
+///   %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < M/8)
+///   %bitcast:_(<K x sX>) = G_BITCAST %vec
+///   %idx:_(s32) = G_CONSTANT i32 N
+///   %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
+/// Or with G_AIE_PAD_VECTOR_UNDEF:
+///   %vec = G_LOAD %ptr :: (unaligned)
+///   %bitcast = G_BITCAST %vec
+///   %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
+///   %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
+/// Converts to:
+///   %offset:_(s20) = G_CONSTANT i20 (N * sizeof(sX))
+///   %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
+///   %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
+///   %result:_(s32) = G_[Z/S]EXT %elt
+bool llvm::matchUnalignedExtractLoad(MachineInstr &ExtractMI,
+                                     MachineRegisterInfo &MRI,
+                                     GISelChangeObserver &Observer,
+                                     BuildFnTy &MatchInfo) {
+  const MachineFunction &MF = *ExtractMI.getMF();
+  const AIEBaseInstrInfo &TII =
+      *static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
+
+  const unsigned Opcode = ExtractMI.getOpcode();
+  const unsigned ZExtExtractOpcode =
+      TII.getGenericExtractVectorEltOpcode(false);
+  const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
+  const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
+
+  const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
+  const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
+  const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);
+
+  if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)
+    return false;
+
+  // Get the index operand
+  const Register IdxReg = ExtractMI.getOperand(2).getReg();
+  const auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
+  if (!IdxCst)
+    return false;
+  const int64_t Index = IdxCst->Value.getSExtValue();
+
+  // Get the vector operand
+  const Register VecReg = ExtractMI.getOperand(1).getReg();
+  const LLT VecTy = MRI.getType(VecReg);
+
+  // Check if vector has extractable element types (s8, s16, or s32)
+  if (!VecTy.isVector())
+    return false;
+
+  const LLT ElemTy = VecTy.getElementType();
+  const unsigned ElemSize = ElemTy.getSizeInBits();
+  if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
+    return false;
+
+  // Trace back through G_AIE_PAD_VECTOR_UNDEF if present
+  MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
+  Register SourceVecReg = VecReg;
+
+  if (VecDefMI->getOpcode() == PadVectorOpcode) {
+    SourceVecReg = VecDefMI->getOperand(1).getReg();
+    VecDefMI = MRI.getVRegDef(SourceVecReg);
+  }
+
+  // Check for G_BITCAST (or direct vector if no bitcast needed)
+  Register LoadVecReg = SourceVecReg;
+  if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
+    LoadVecReg = VecDefMI->getOperand(1).getReg();
+
+  MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);
+
+  // Check if it's a load
+  if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
+    return false;
+
+  // Check if the load is unaligned relative to the vector's total size
+  if (LoadMI->memoperands_empty())
+    return false;
+
+  const MachineMemOperand *MMO = LoadMI->memoperands().front();
+  const LLT LoadVecTy = MRI.getType(LoadVecReg);
+  const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
+  // Vector is unaligned if alignment < vector size
+  // This allows extracting elements when the vector load itself is unaligned
+  if (MMO->getAlign().value() >= LoadVecSizeInBytes)
+    return false;
+
+  // Check that the loaded vector is only used by extracts (through bitcast and
+  // pad). The helper function will automatically traverse through bitcasts.
+  const Register LoadDstReg = LoadMI->getOperand(0).getReg();
+
+  if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
+                                         SExtExtractOpcode, PadVectorOpcode,
+                                         MRI))
+    return false;
+
+  // All checks passed, we can combine
+  MatchInfo = [=, &ExtractMI, &MRI, &Observer](MachineIRBuilder &B) {
+    const Register PtrReg = LoadMI->getOperand(1).getReg();
+    const LLT S20 = LLT::scalar(20);
+
+    const unsigned ElemSizeInBytes = ElemSize / 8;
+    const int64_t ByteOffset = Index * ElemSizeInBytes;
+
+    // Set insertion point right after the original vector load
+    B.setInsertPt(*LoadMI->getParent(), std::next(LoadMI->getIterator()));
+    B.setDebugLoc(LoadMI->getDebugLoc());
+
+    // Create offset constant and pointer add
+    const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
+    const Register NewPtrReg =
+        B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);
+
+    // Calculate alignment for scalar load based on original vector load
+    // alignment using GCD to find the maximum provable alignment
+    const unsigned OrigAlign = MMO->getAlign().value();
+    const unsigned ScalarAlign = std::gcd(OrigAlign, OrigAlign + ByteOffset);
+
+    // Create new scalar load with derived alignment
+    MachineFunction &MF = B.getMF();
+    MachineMemOperand *NewMMO =
+        MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
+                                ElemSizeInBytes, Align(ScalarAlign));
+
+    const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
+    B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO);
+
+    // Now set insertion point at the extract position for the copy/extension
+    B.setInstr(ExtractMI);
+
+    // Handle the result based on the original opcode
+    if (IsZExtExtract || IsSExtExtract) {
+      // Need to extend to s32
+      const Register DstReg = ExtractMI.getOperand(0).getReg();
+      if (IsZExtExtract)
+        B.buildZExt(DstReg, LoadResultReg);
+      else
+        B.buildSExt(DstReg, LoadResultReg);
+    } else {
+      // G_EXTRACT_VECTOR_ELT
+      const Register DstReg = ExtractMI.getOperand(0).getReg();
+      // Just copy the result
+      B.buildCopy(DstReg, LoadResultReg);
+    }
+
+    Observer.erasingInstr(ExtractMI);
+    ExtractMI.eraseFromParent();
+  };
+
+  return true;
+}

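A note on the alignment derivation in the apply lambda above: std::gcd(OrigAlign, OrigAlign + ByteOffset) equals gcd(OrigAlign, ByteOffset), the largest alignment provable for ptr + ByteOffset, and it degrades gracefully to OrigAlign when the offset is zero. A minimal standalone sketch with illustrative values (not taken from the patch):

#include <cstdint>
#include <cstdio>
#include <numeric>

// Largest alignment provable for (ptr + ByteOffset) when ptr is known to
// be OrigAlign-aligned; the same std::gcd expression the combiner uses.
static unsigned derivedAlign(unsigned OrigAlign, int64_t ByteOffset) {
  return static_cast<unsigned>(
      std::gcd(static_cast<int64_t>(OrigAlign),
               static_cast<int64_t>(OrigAlign) + ByteOffset));
}

int main() {
  // From a 4-byte-aligned base: offset 2 proves only 2-byte alignment,
  // offset 4 keeps the full 4 bytes, and offset 0 keeps it as well.
  std::printf("%u %u %u\n", derivedAlign(4, 2), derivedAlign(4, 4),
              derivedAlign(4, 0)); // prints: 2 4 4
  return 0;
}
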
llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 5 additions & 0 deletions
@@ -308,6 +308,11 @@ bool matchExtractVecEltAssertBcst(MachineInstr &MI, MachineRegisterInfo &MRI,
                                   const AIEBaseInstrInfo &TII,
                                   GISelChangeObserver &Observer,
                                   BuildFnTy &MatchInfo);
+
+bool matchUnalignedExtractLoad(MachineInstr &ExtractMI,
+                               MachineRegisterInfo &MRI,
+                               GISelChangeObserver &Observer,
+                               BuildFnTy &MatchInfo);
 } // namespace llvm
 
 #endif

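The legality walk in areLoadUsesValidForExtractCombine is easiest to picture on a toy def-use graph. The sketch below uses hypothetical stand-in types (not LLVM's MIR classes) but mirrors the same logic: extracts are terminal and always fine, bitcasts are looked through recursively, a pad is accepted only if all of its users are extracts, and anything else blocks the combine:

#include <vector>

enum class Op { Extract, Bitcast, Pad, Other };

struct Node {
  Op Opcode;
  std::vector<const Node *> Uses; // users of this node's result
};

static bool usesAreExtractOnly(const Node &N) {
  for (const Node *U : N.Uses) {
    switch (U->Opcode) {
    case Op::Extract:
      break; // terminal use, always valid
    case Op::Bitcast:
      if (!usesAreExtractOnly(*U)) // look through the cast recursively
        return false;
      break;
    case Op::Pad:
      for (const Node *PU : U->Uses) // pad is valid iff only extracted from
        if (PU->Opcode != Op::Extract)
          return false;
      break;
    default:
      return false; // any other use blocks the combine
    }
  }
  return true;
}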