17 changes: 15 additions & 2 deletions llvm/lib/Target/AIE/AIECombine.td
@@ -252,6 +252,18 @@ def combine_trunc_load : GICombineRule<
[{ return matchNarrowTruncLoad(*${root}, MRI, Helper, Observer, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;

def combine_unaligned_extract_load : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_EXTRACT_VECTOR_ELT, G_AIE_ZEXT_EXTRACT_VECTOR_ELT, G_AIE_SEXT_EXTRACT_VECTOR_ELT): $root,
[{ return matchUnalignedExtractLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFnNoErase(*${root}, ${matchinfo}); }])>;

def combine_unaligned_vector_load : GICombineRule<
(defs root:$root, build_fn_matchinfo:$matchinfo),
(match (wip_match_opcode G_LOAD): $root,
[{ return matchUnalignedVectorLoad(*${root}, MRI, Observer, ${matchinfo}); }]),
(apply [{ Helper.applyBuildFn(*${root}, ${matchinfo}); }])>;

// AIE-specific combines (currently shared by AIE2 and AIE2P).
def aie_additional_combines : GICombineGroup<[
combine_unpad_vector,
@@ -274,7 +286,9 @@ def aie_additional_combines : GICombineGroup<[
combine_align_memset,
combine_peel_memset,
combine_pack_stores_into_memset,
combine_trunc_load,
combine_unaligned_extract_load,
combine_unaligned_vector_load
]>;

// AIE2P-specific combines.
@@ -408,4 +422,3 @@ def AIE2PPostLegalizerCustomCombiner
combine_add_vector_elt_undef,
]> {
}

321 changes: 321 additions & 0 deletions llvm/lib/Target/AIE/AIECombinerHelper.cpp
@@ -4292,3 +4292,324 @@ bool llvm::matchExtractVecEltAssertBcst(MachineInstr &MI,

return true;
}

/// Recursively check whether all uses of a register are valid for the
/// unaligned extract-load combine.
/// Bitcasts are looked through so that every usage pattern is validated.
/// Valid terminal uses are direct extracts, or pad vector operations whose
/// own uses are all extracts.
static bool areLoadUsesValidForExtractCombine(Register Reg,
unsigned ZExtExtractOpcode,
unsigned SExtExtractOpcode,
unsigned PadVectorOpcode,
MachineRegisterInfo &MRI) {

auto IsValidExtractOpcode = [&](unsigned Opcode) {
return Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT ||
Opcode == ZExtExtractOpcode || Opcode == SExtExtractOpcode;
};

for (const MachineInstr &Use : MRI.use_nodbg_instructions(Reg)) {
const unsigned UseOpcode = Use.getOpcode();

if (UseOpcode == TargetOpcode::G_BITCAST) {
// Recursively check bitcast uses
const Register BitcastDst = Use.getOperand(0).getReg();
if (!areLoadUsesValidForExtractCombine(BitcastDst, ZExtExtractOpcode,
SExtExtractOpcode, PadVectorOpcode,
MRI))
return false;
continue;
}

if (IsValidExtractOpcode(UseOpcode)) {
// Direct extract is valid (plain, zext, or sext)
continue;
}

if (UseOpcode == PadVectorOpcode) {
// Pad is valid if only used by extracts
const Register PadDst = Use.getOperand(0).getReg();
for (const MachineInstr &PadUse : MRI.use_nodbg_instructions(PadDst)) {
if (!IsValidExtractOpcode(PadUse.getOpcode()))
return false;
}
continue;
}

// Invalid use
return false;
}

return true;
}

/// Match unaligned vector loads that are only used for extracting elements
/// and convert them to direct scalar loads.
/// Supports s8, s16 and s32 element extractions from various vector
/// configurations. Pattern:
/// %vec:_(<N x sM>) = G_LOAD %ptr(p0) :: (align < M/8)
/// %bitcast:_(<K x sX>) = G_BITCAST %vec
/// %idx:_(s32) = G_CONSTANT i32 N
/// %elt:_(sX) = G_EXTRACT_VECTOR_ELT %bitcast, %idx
/// Or with G_AIE_PAD_VECTOR_UNDEF:
/// %vec = G_LOAD %ptr :: (unaligned)
/// %bitcast = G_BITCAST %vec
/// %padded = G_AIE_PAD_VECTOR_UNDEF %bitcast
/// %result:_(s32) = G_AIE_[Z/S]EXT_EXTRACT_VECTOR_ELT %padded, %idx
/// Converts to:
/// %offset:_(s20) = G_CONSTANT i20 (N * sizeof(sX))
/// %new_ptr:_(p0) = G_PTR_ADD %ptr, %offset
/// %elt:_(sX) = G_LOAD %new_ptr :: (align 1)
/// %result:_(s32) = G_[Z/S]EXT %elt
bool llvm::matchUnalignedExtractLoad(MachineInstr &ExtractMI,
MachineRegisterInfo &MRI,
GISelChangeObserver &Observer,
BuildFnTy &MatchInfo) {
const MachineFunction &MF = *ExtractMI.getMF();
const AIEBaseInstrInfo &TII =
*static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());

const unsigned Opcode = ExtractMI.getOpcode();
const unsigned ZExtExtractOpcode =
TII.getGenericExtractVectorEltOpcode(false);
const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();

const bool IsZExtExtract = (Opcode == ZExtExtractOpcode);
const bool IsSExtExtract = (Opcode == SExtExtractOpcode);
const bool IsPlainExtract = (Opcode == TargetOpcode::G_EXTRACT_VECTOR_ELT);

if (!IsZExtExtract && !IsSExtExtract && !IsPlainExtract)

Collaborator: It looks as if the pattern's opcode check precludes this case?

return false;

// Get the index operand
const Register IdxReg = ExtractMI.getOperand(2).getReg();
const auto IdxCst = getIConstantVRegValWithLookThrough(IdxReg, MRI);
if (!IdxCst)
return false;
const int64_t Index = IdxCst->Value.getSExtValue();

// Get the vector operand
const Register VecReg = ExtractMI.getOperand(1).getReg();
const LLT VecTy = MRI.getType(VecReg);

// Check if vector has extractable element types (s8, s16, or s32)
if (!VecTy.isVector())
return false;

const LLT ElemTy = VecTy.getElementType();
const unsigned ElemSize = ElemTy.getSizeInBits();
if (ElemSize != 8 && ElemSize != 16 && ElemSize != 32)
return false;

// Trace back through G_AIE_PAD_VECTOR_UNDEF if present
MachineInstr *VecDefMI = MRI.getVRegDef(VecReg);
Register SourceVecReg = VecReg;

if (VecDefMI->getOpcode() == PadVectorOpcode) {
SourceVecReg = VecDefMI->getOperand(1).getReg();
VecDefMI = MRI.getVRegDef(SourceVecReg);
}

// Check for G_BITCAST (or direct vector if no bitcast needed)
Register LoadVecReg = SourceVecReg;
if (VecDefMI->getOpcode() == TargetOpcode::G_BITCAST)
LoadVecReg = VecDefMI->getOperand(1).getReg();

MachineInstr *LoadMI = MRI.getVRegDef(LoadVecReg);

// Check if it's a load
if (LoadMI->getOpcode() != TargetOpcode::G_LOAD)
return false;

// Check if the load is unaligned relative to the vector's total size
if (LoadMI->memoperands_empty())
return false;

const MachineMemOperand *MMO = LoadMI->memoperands().front();
const LLT LoadVecTy = MRI.getType(LoadVecReg);
const unsigned LoadVecSizeInBytes = LoadVecTy.getSizeInBytes();
// Vector is unaligned if alignment < vector size
// This allows extracting elements when the vector load itself is unaligned
if (MMO->getAlign().value() >= LoadVecSizeInBytes)
return false;

// Check that the loaded vector is only used by extracts (through bitcast and
// pad). The helper function will automatically traverse through bitcasts.
const Register LoadDstReg = LoadMI->getOperand(0).getReg();

if (!areLoadUsesValidForExtractCombine(LoadDstReg, ZExtExtractOpcode,
SExtExtractOpcode, PadVectorOpcode,
MRI))
return false;

// All checks passed, we can combine
MatchInfo = [=, &ExtractMI, &MRI, &Observer](MachineIRBuilder &B) {
const Register PtrReg = LoadMI->getOperand(1).getReg();
const LLT S20 = LLT::scalar(20);

const unsigned ElemSizeInBytes = ElemSize / 8;
const int64_t ByteOffset = Index * ElemSizeInBytes;

// Set insertion point right after the original vector load
B.setInsertPt(*LoadMI->getParent(), std::next(LoadMI->getIterator()));
B.setDebugLoc(LoadMI->getDebugLoc());

// Create offset constant and pointer add
const Register OffsetReg = B.buildConstant(S20, ByteOffset).getReg(0);
const Register NewPtrReg =
B.buildPtrAdd(MRI.getType(PtrReg), PtrReg, OffsetReg).getReg(0);

// Derive the scalar load's alignment from the original vector load's
// alignment and the byte offset, using GCD to get the maximum provable
// alignment.
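// e.g. an align-4 vector load with a byte offset of 6 gives gcd(4, 4 + 6) = 2.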
const unsigned OrigAlign = MMO->getAlign().value();
const unsigned ScalarAlign = std::gcd(OrigAlign, OrigAlign + ByteOffset);

// Create new scalar load with derived alignment
MachineFunction &MF = B.getMF();
MachineMemOperand *NewMMO =
MF.getMachineMemOperand(MMO->getPointerInfo(), MMO->getFlags(),
ElemSizeInBytes, Align(ScalarAlign));

const Register LoadResultReg = MRI.createGenericVirtualRegister(ElemTy);
Observer.createdInstr(*B.buildLoad(LoadResultReg, NewPtrReg, *NewMMO));

// Now set insertion point at the extract position for the copy/extension
B.setInstr(ExtractMI);

// Handle the result based on the original opcode
const Register DstReg = ExtractMI.getOperand(0).getReg();
if (IsZExtExtract) {
// Need to zero-extend to s32
Observer.createdInstr(*B.buildZExt(DstReg, LoadResultReg));
} else if (IsSExtExtract) {
// Need to sign-extend to s32
Observer.createdInstr(*B.buildSExt(DstReg, LoadResultReg));
} else {
// G_EXTRACT_VECTOR_ELT
// Just copy the result
Observer.createdInstr(*B.buildCopy(DstReg, LoadResultReg));
}

Observer.erasingInstr(ExtractMI);
ExtractMI.eraseFromParent();
};

return true;
}

/// Match unaligned vector loads and transform them to use a better-aligned
/// element type based on the actual alignment.
/// Pattern:
/// %vec:_(<32 x s16>) = G_LOAD %ptr(p0) :: (align 4)
/// Converts to:
/// %vec_new:_(<16 x s32>) = G_LOAD %ptr(p0) :: (align 4)
/// %vec:_(<32 x s16>) = G_BITCAST %vec_new(<16 x s32>)
bool llvm::matchUnalignedVectorLoad(MachineInstr &LoadMI,
MachineRegisterInfo &MRI,
GISelChangeObserver &Observer,
BuildFnTy &MatchInfo) {
assert(LoadMI.getOpcode() == TargetOpcode::G_LOAD && "Expected G_LOAD");

// Get load information
const Register DstReg = LoadMI.getOperand(0).getReg();
const LLT DstTy = MRI.getType(DstReg);

// Only process vector loads
if (!DstTy.isVector())
return false;

// Check memory operand for alignment
if (LoadMI.memoperands_empty())
return false;

const MachineMemOperand *MMO = LoadMI.memoperands().front();
const unsigned Alignment = MMO->getAlign().value();

// Skip if the vector is already well-aligned (alignment >= vector size)
const unsigned VecSizeInBytes = DstTy.getSizeInBytes();
if (Alignment >= VecSizeInBytes)
return false;

// Get element type information
const LLT ElemTy = DstTy.getElementType();
const unsigned ElemSizeInBits = ElemTy.getSizeInBits();

// Skip if the load is only used for extracts - let matchUnalignedExtractLoad
// handle it. This prevents the two combiners from competing for the same
// opportunities
const MachineFunction &MF = *LoadMI.getMF();
const AIEBaseInstrInfo &TII =
*static_cast<const AIEBaseInstrInfo *>(MF.getSubtarget().getInstrInfo());
const unsigned ZExtExtractOpcode =
TII.getGenericExtractVectorEltOpcode(false);
const unsigned SExtExtractOpcode = TII.getGenericExtractVectorEltOpcode(true);
const unsigned PadVectorOpcode = TII.getGenericPadVectorOpcode();

if (areLoadUsesValidForExtractCombine(
DstReg, ZExtExtractOpcode, SExtExtractOpcode, PadVectorOpcode, MRI))
return false;

// Skip if the load has a single user that is a G_STORE with the same
// alignment. This case can be perfectly scalarized during legalization
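// (e.g. copying a <32 x s16> value between two align-4 locations can be
// split into aligned chunks by the legalizer without this combine).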
if (MRI.hasOneNonDBGUse(DstReg)) {
const MachineInstr *UserMI = &*MRI.use_instr_nodbg_begin(DstReg);
if (UserMI->getOpcode() == TargetOpcode::G_STORE) {
const GStore *StoreMI = cast<GStore>(UserMI);
if (!StoreMI->memoperands_empty()) {
const MachineMemOperand *StoreMMO = StoreMI->memoperands().front();
// If store has the same alignment as the load, skip
if (StoreMMO->getAlign().value() == Alignment)
return false;
}
}
}

// We already have the best element size option.
if (Alignment == ElemSizeInBits / 8)
return false;

// Only handle s8 and s16 element types that can be promoted to s32
if (ElemSizeInBits != 8 && ElemSizeInBits != 16)
return false;

// Determine the optimal element type based on alignment
unsigned NewElemSizeInBits = 0;
if (Alignment >= 4) {
NewElemSizeInBits = 32;
} else if (Alignment >= 2) {
NewElemSizeInBits = 16;
} else {
// Alignment doesn't allow for a better element type
return false;
}
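// e.g. with align 4, a <32 x s16> load is re-typed below as <16 x s32>,
// matching the example in the function's doc comment.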

// Check if the vector size is compatible with the new element size
const unsigned VecSizeInBits = DstTy.getSizeInBits();
if (VecSizeInBits % NewElemSizeInBits != 0)
return false;

MatchInfo = [=, PtrReg = LoadMI.getOperand(1).getReg(), &MRI,
&Observer](MachineIRBuilder &B) {
MachineFunction &MF = B.getMF();

// Calculate new number of elements
const unsigned NewNumElems = VecSizeInBits / NewElemSizeInBits;

// Create the new vector type with better-aligned elements
const LLT NewVecTy = LLT::fixed_vector(NewNumElems, NewElemSizeInBits);
const Register NewLoadReg = MRI.createGenericVirtualRegister(NewVecTy);

// Create a new MMO with the same properties but updated type
MachineMemOperand *NewMMO = MF.getMachineMemOperand(
MMO->getPointerInfo(), MMO->getFlags(), NewVecTy, MMO->getAlign());

Observer.createdInstr(*B.buildLoad(NewLoadReg, PtrReg, *NewMMO));

// Bitcast back to the original type
Observer.createdInstr(*B.buildBitcast(DstReg, NewLoadReg));
};

return true;
}
9 changes: 9 additions & 0 deletions llvm/lib/Target/AIE/AIECombinerHelper.h
@@ -308,6 +308,15 @@ bool matchExtractVecEltAssertBcst(MachineInstr &MI, MachineRegisterInfo &MRI,
const AIEBaseInstrInfo &TII,
GISelChangeObserver &Observer,
BuildFnTy &MatchInfo);

bool matchUnalignedExtractLoad(MachineInstr &ExtractMI,
MachineRegisterInfo &MRI,
GISelChangeObserver &Observer,
BuildFnTy &MatchInfo);

bool matchUnalignedVectorLoad(MachineInstr &LoadMI, MachineRegisterInfo &MRI,
GISelChangeObserver &Observer,
BuildFnTy &MatchInfo);
} // namespace llvm

#endif