Skip to content

Commit 20e7808

Browse files
committed
[AIEX] Add a combiner to change vector load element type based on alignment
In this case, we can improve the legalized code.
1 parent b5d7a42 commit 20e7808

File tree

5 files changed

+760
-5
lines changed

5 files changed

+760
-5
lines changed

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ def combine_unaligned_extract_load : GICombineRule<
258258
[{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
259259
(apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
260260

261+
def combine_unaligned_vector_load : GICombineRule<
262+
(defs root:$root, build_fn_matchinfo:$matchinfo),
263+
(match (wip_match_opcode G_LOAD): $root,
264+
[{ return matchUnalignedVectorLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
265+
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
266+
261267
// AIE-specifc combines (currently shared by AIE2 and AIE2P).
262268
def aie_additional_combines : GICombineGroup<[
263269
combine_unpad_vector,
@@ -281,7 +287,8 @@ def aie_additional_combines : GICombineGroup<[
281287
combine_peel_memset,
282288
combine_pack_stores_into_memset,
283289
combine_trunc_load,
284-
combine_unaligned_extract_load
290+
combine_unaligned_extract_load,
291+
combine_unaligned_vector_load
285292
]>;
286293

287294
// AIE2P-specific combines.

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4498,3 +4498,118 @@ bool llvm::matchUnalignedExtractLoad(MachineInstr &ExtractMI,
44984498

44994499
return true;
45004500
}
4501+
4502+
/// Match unaligned vector loads and transform them to use a better-aligned
4503+
/// element type based on the actual alignment.
4504+
/// Pattern:
4505+
/// %vec:_(<32 x s16>) = G_LOAD %ptr(p0) :: (align 4)
4506+
/// Converts to:
4507+
/// %vec_new:_(<16 x s32>) = G_LOAD %ptr(p0) :: (align 4)
4508+
/// %vec:_(<32 x s16>) = G_BITCAST %vec_new(<16 x s32>)
4509+
bool llvm::matchUnalignedVectorLoad(MachineInstr &LoadMI,
4510+
MachineRegisterInfo &MRI,
4511+
GISelChangeObserver &Observer,
4512+
BuildFnTy &MatchInfo) {
4513+
assert(LoadMI.getOpcode() == TargetOpcode::G_LOAD && "Expected G_LOAD");
4514+
4515+
// Get load information
4516+
const Register DstReg = LoadMI.getOperand(0).getReg();
4517+
const LLT DstTy = MRI.getType(DstReg);
4518+
4519+
// Only process vector loads
4520+
if (!DstTy.isVector())
4521+
return false;
4522+
4523+
// Check memory operand for alignment
4524+
if (LoadMI.memoperands_empty())
4525+
return false;
4526+
4527+
const MachineMemOperand *MMO = LoadMI.memoperands().front();
4528+
const unsigned Alignment = MMO->getAlign().value();
4529+
4530+
// Skip if the vector is already well-aligned (alignment >= vector size)
4531+
const unsigned VecSizeInBytes = DstTy.getSizeInBytes();
4532+
if (Alignment >= VecSizeInBytes)
4533+
return false;
4534+
4535+
// Get element type information
4536+
const LLT ElemTy = DstTy.getElementType();
4537+
const unsigned ElemSize = ElemTy.getSizeInBits();
4538+
4539+
// Skip if the load is only used for extracts - let matchUnalignedExtractLoad
4540+
// handle it This prevents the two combiners from competing for the same
4541+
// opportunities
4542+
const MachineFunction &MF = *LoadMI.getMF();
4543+
const AIEBaseInstrInfo &TII =
4544+
*static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
4545+
const unsigned ZExtExtractOpcode =
4546+
TII.getGenericExtractVectorEltOpcode(false);
4547+
const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
4548+
const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
4549+
4550+
if (areLoadUsesValidForExtractCombine(
4551+
DstReg, ZExtExtractOpcode, SExtExtractOpcode, PadVectorOpcode, MRI))
4552+
return false;
4553+
4554+
// Skip if the load has a single user that is a G_STORE with the same
4555+
// alignment This case can be perfectly scalarized during legalization
4556+
if (MRI.hasOneNonDBGUse(DstReg)) {
4557+
const MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
4558+
if (UserMI->getOpcode() == TargetOpcode::G_STORE) {
4559+
const GStore *StoreMI = cast<GStore>(UserMI);
4560+
if (!StoreMI->memoperands_empty()) {
4561+
const MachineMemOperand *StoreMMO = StoreMI->memoperands().front();
4562+
// If store has the same alignment as the load, skip
4563+
if (StoreMMO->getAlign().value() == Alignment)
4564+
return false;
4565+
}
4566+
}
4567+
}
4568+
4569+
// Only handle s8 and s16 element types that can be promoted to s32
4570+
if (ElemSize != 8 && ElemSize != 16)
4571+
return false;
4572+
4573+
// Determine the optimal element type based on alignment
4574+
unsigned NewElemSize = 0;
4575+
if (Alignment >= 8 && ElemSize < 64) {
4576+
NewElemSize = 64;
4577+
} else if (Alignment >= 4 && ElemSize < 32) {
4578+
NewElemSize = 32;
4579+
} else if (Alignment >= 2 && ElemSize < 16) {
4580+
NewElemSize = 16;
4581+
} else {
4582+
// Alignment doesn't allow for a better element type
4583+
return false;
4584+
}
4585+
4586+
// Check if the vector size is compatible with the new element size
4587+
const unsigned VecSizeInBits = DstTy.getSizeInBits();
4588+
if (VecSizeInBits % NewElemSize != 0)
4589+
return false;
4590+
4591+
// Calculate new number of elements
4592+
const unsigned NewNumElems = VecSizeInBits / NewElemSize;
4593+
4594+
// Capture the pointer register before creating the lambda
4595+
const Register PtrReg = LoadMI.getOperand(1).getReg();
4596+
4597+
MatchInfo = [=, &MRI, &Observer](MachineIRBuilder &B) {
4598+
MachineFunction &MF = B.getMF();
4599+
4600+
// Create the new vector type with better-aligned elements
4601+
const LLT NewVecTy = LLT::fixed_vector(NewNumElems, NewElemSize);
4602+
const Register NewLoadReg = MRI.createGenericVirtualRegister(NewVecTy);
4603+
4604+
// Create a new MMO with the same properties but updated type
4605+
MachineMemOperand *NewMMO = MF.getMachineMemOperand(
4606+
MMO->getPointerInfo(), MMO->getFlags(), NewVecTy, MMO->getAlign());
4607+
4608+
Observer.createdInstr(*B.buildLoad(NewLoadReg, PtrReg, *NewMMO));
4609+
4610+
// Bitcast back to the original type
4611+
Observer.createdInstr(*B.buildBitcast(DstReg, NewLoadReg));
4612+
};
4613+
4614+
return true;
4615+
}

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ bool matchUnalignedExtractLoad(MachineInstr &ExtractMI,
313313
MachineRegisterInfo &MRI,
314314
GISelChangeObserver &Observer,
315315
BuildFnTy &MatchInfo);
316+
317+
bool matchUnalignedVectorLoad(MachineInstr &LoadMI, MachineRegisterInfo &MRI,
318+
GISelChangeObserver &Observer,
319+
BuildFnTy &MatchInfo);
316320
} // namespace llvm
317321

318322
#endif

0 commit comments

Comments
 (0)