Skip to content

Commit e684035

Browse files
committed
[AIEX] Add a combiner to change vector load element type based on alignment
In this case, we can improve the legalized code.
1 parent e6e7fd7 commit e684035

File tree

5 files changed

+760
-5
lines changed

5 files changed

+760
-5
lines changed

llvm/lib/Target/AIE/AIECombine.td

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,12 @@ def combine_unaligned_extract_load : GICombineRule<
258258
[{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
259259
(apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;
260260

261+
def combine_unaligned_vector_load : GICombineRule<
262+
(defs root:$root, build_fn_matchinfo:$matchinfo),
263+
(match (wip_match_opcode G_LOAD): $root,
264+
[{ return matchUnalignedVectorLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
265+
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;
266+
261267
// AIE-specifc combines (currently shared by AIE2 and AIE2P).
262268
def aie_additional_combines : GICombineGroup<[
263269
combine_unpad_vector,
@@ -281,7 +287,8 @@ def aie_additional_combines : GICombineGroup<[
281287
combine_peel_memset,
282288
combine_pack_stores_into_memset,
283289
combine_trunc_load,
284-
combine_unaligned_extract_load
290+
combine_unaligned_extract_load,
291+
combine_unaligned_vector_load
285292
]>;
286293

287294
// AIE2P-specific combines.

llvm/lib/Target/AIE/AIECombinerHelper.cpp

Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4499,3 +4499,118 @@ bool llvm::matchUnalignedExtractLoad(MachineInstr &ExtractMI,
44994499

45004500
return true;
45014501
}
4502+
4503+
/// Match unaligned vector loads and transform them to use a better-aligned
4504+
/// element type based on the actual alignment.
4505+
/// Pattern:
4506+
/// %vec:_(<32 x s16>) = G_LOAD %ptr(p0) :: (align 4)
4507+
/// Converts to:
4508+
/// %vec_new:_(<16 x s32>) = G_LOAD %ptr(p0) :: (align 4)
4509+
/// %vec:_(<32 x s16>) = G_BITCAST %vec_new(<16 x s32>)
4510+
bool llvm::matchUnalignedVectorLoad(MachineInstr &LoadMI,
4511+
MachineRegisterInfo &MRI,
4512+
GISelChangeObserver &Observer,
4513+
BuildFnTy &MatchInfo) {
4514+
assert(LoadMI.getOpcode() == TargetOpcode::G_LOAD && "Expected G_LOAD");
4515+
4516+
// Get load information
4517+
const Register DstReg = LoadMI.getOperand(0).getReg();
4518+
const LLT DstTy = MRI.getType(DstReg);
4519+
4520+
// Only process vector loads
4521+
if (!DstTy.isVector())
4522+
return false;
4523+
4524+
// Check memory operand for alignment
4525+
if (LoadMI.memoperands_empty())
4526+
return false;
4527+
4528+
const MachineMemOperand *MMO = LoadMI.memoperands().front();
4529+
const unsigned Alignment = MMO->getAlign().value();
4530+
4531+
// Skip if the vector is already well-aligned (alignment >= vector size)
4532+
const unsigned VecSizeInBytes = DstTy.getSizeInBytes();
4533+
if (Alignment >= VecSizeInBytes)
4534+
return false;
4535+
4536+
// Get element type information
4537+
const LLT ElemTy = DstTy.getElementType();
4538+
const unsigned ElemSize = ElemTy.getSizeInBits();
4539+
4540+
// Skip if the load is only used for extracts - let matchUnalignedExtractLoad
4541+
// handle it This prevents the two combiners from competing for the same
4542+
// opportunities
4543+
const MachineFunction &MF = *LoadMI.getMF();
4544+
const AIEBaseInstrInfo &TII =
4545+
*static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
4546+
const unsigned ZExtExtractOpcode =
4547+
TII.getGenericExtractVectorEltOpcode(false);
4548+
const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
4549+
const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();
4550+
4551+
if (areLoadUsesValidForExtractCombine(
4552+
DstReg, ZExtExtractOpcode, SExtExtractOpcode, PadVectorOpcode, MRI))
4553+
return false;
4554+
4555+
// Skip if the load has a single user that is a G_STORE with the same
4556+
// alignment This case can be perfectly scalarized during legalization
4557+
if (MRI.hasOneNonDBGUse(DstReg)) {
4558+
const MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
4559+
if (UserMI->getOpcode() == TargetOpcode::G_STORE) {
4560+
const GStore *StoreMI = cast<GStore>(UserMI);
4561+
if (!StoreMI->memoperands_empty()) {
4562+
const MachineMemOperand *StoreMMO = StoreMI->memoperands().front();
4563+
// If store has the same alignment as the load, skip
4564+
if (StoreMMO->getAlign().value() == Alignment)
4565+
return false;
4566+
}
4567+
}
4568+
}
4569+
4570+
// Only handle s8 and s16 element types that can be promoted to s32
4571+
if (ElemSize != 8 && ElemSize != 16)
4572+
return false;
4573+
4574+
// Determine the optimal element type based on alignment
4575+
unsigned NewElemSize = 0;
4576+
if (Alignment >= 8 && ElemSize < 64) {
4577+
NewElemSize = 64;
4578+
} else if (Alignment >= 4 && ElemSize < 32) {
4579+
NewElemSize = 32;
4580+
} else if (Alignment >= 2 && ElemSize < 16) {
4581+
NewElemSize = 16;
4582+
} else {
4583+
// Alignment doesn't allow for a better element type
4584+
return false;
4585+
}
4586+
4587+
// Check if the vector size is compatible with the new element size
4588+
const unsigned VecSizeInBits = DstTy.getSizeInBits();
4589+
if (VecSizeInBits % NewElemSize != 0)
4590+
return false;
4591+
4592+
// Calculate new number of elements
4593+
const unsigned NewNumElems = VecSizeInBits / NewElemSize;
4594+
4595+
// Capture the pointer register before creating the lambda
4596+
const Register PtrReg = LoadMI.getOperand(1).getReg();
4597+
4598+
MatchInfo = [=, &MRI, &Observer](MachineIRBuilder &B) {
4599+
MachineFunction &MF = B.getMF();
4600+
4601+
// Create the new vector type with better-aligned elements
4602+
const LLT NewVecTy = LLT::fixed_vector(NewNumElems, NewElemSize);
4603+
const Register NewLoadReg = MRI.createGenericVirtualRegister(NewVecTy);
4604+
4605+
// Create a new MMO with the same properties but updated type
4606+
MachineMemOperand *NewMMO = MF.getMachineMemOperand(
4607+
MMO->getPointerInfo(), MMO->getFlags(), NewVecTy, MMO->getAlign());
4608+
4609+
Observer.createdInstr(*B.buildLoad(NewLoadReg, PtrReg, *NewMMO));
4610+
4611+
// Bitcast back to the original type
4612+
Observer.createdInstr(*B.buildBitcast(DstReg, NewLoadReg));
4613+
};
4614+
4615+
return true;
4616+
}

llvm/lib/Target/AIE/AIECombinerHelper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,6 +313,10 @@ bool matchUnalignedExtractLoad(MachineInstr &ExtractMI,
313313
MachineRegisterInfo &MRI,
314314
GISelChangeObserver &Observer,
315315
BuildFnTy &MatchInfo);
316+
317+
bool matchUnalignedVectorLoad(MachineInstr &LoadMI, MachineRegisterInfo &MRI,
318+
GISelChangeObserver &Observer,
319+
BuildFnTy &MatchInfo);
316320
} // namespace llvm
317321

318322
#endif

0 commit comments

Comments
 (0)