134 changes: 106 additions & 28 deletions llvm/lib/Target/AMDGPU/AMDGPULowerBufferFatPointers.cpp
@@ -250,6 +250,7 @@
#include "llvm/Support/AtomicOrdering.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/Local.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
@@ -688,6 +689,10 @@ class LegalizeBufferContentTypesVisitor

const DataLayout &DL;

// Subtarget info, needed for determining the legal load/store widths and
// the out-of-bounds behavior of buffer accesses.
const TargetMachine *TM;
const GCNSubtarget *ST = nullptr;

/// If T is [N x U], where U is a scalar type, return the vector type
/// <N x U>, otherwise, return T.
Type *scalarArrayTypeAsVector(Type *MaybeArrayType);
@@ -696,10 +701,32 @@ class LegalizeBufferContentTypesVisitor

/// Break up the loads of a struct into the loads of its components

/// Return the maximum allowed load/store width for the given type and
/// alignment combination based on subtarget flags.
/// 1. If unaligned accesses are not enabled, then any load/store that is less
/// than word-aligned has to be handled one byte or ushort at a time.
/// 2. If relaxed OOB mode is not set, we must ensure that the in-bounds
/// part of a partially out of bounds read/write is performed correctly. This
/// means that any load that isn't naturally aligned has to be split into
/// parts that are naturally aligned, so that, after bitcasting, we don't have
/// unaligned loads that could discard valid data.
///
/// For example, if we're loading a <8 x i8>, that's actually a load of a <2 x
/// i32>, and if we load from an align(2) address, that address might be 2
/// bytes from the end of the buffer. The hardware will, when performing the
/// <2 x i32> load, mask off the entire first word, discarding the two
/// in-bounds bytes along with the out-of-bounds ones.
///
/// Unlike the complete disablement of unaligned accesses from point 1,
/// this does not apply to unaligned scalars, but will apply to cases like
/// `load <2 x i32>, align 4` since the left element might be out of bounds.
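///
/// For example, with unaligned buffer access enabled but relaxed OOB mode
/// disabled, this returns 8 for (<8 x i8>, align(2)), forcing byte-sized
/// pieces, 32 for (<2 x i32>, align(4)), and 128 for (<4 x i32>, align(16)),
/// which leaves the full vector operation intact.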
uint64_t maxIntrinsicWidth(Type *Ty, Align A);

/// Convert a vector or scalar type that can't be operated on by buffer
/// intrinsics to one that would be legal through bitcasts and/or truncation.
/// Uses the wider of i32, i16, or i8 where possible.
Type *legalNonAggregateFor(Type *T);
/// Uses the wider of i32, i16, or i8 where possible, clamping to the maximum
/// allowed width under the alignment rules and subtarget flags.
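/// For example, a 48-bit <6 x i8> becomes <3 x i16> when `MaxWidth` is at
/// least 16, but stays in byte-sized elements when `MaxWidth` is 8.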
Type *legalNonAggregateForMemOp(Type *T, uint64_t MaxWidth);
Value *makeLegalNonAggregate(Value *V, Type *TargetType, const Twine &Name);
Value *makeIllegalNonAggregate(Value *V, Type *OrigType, const Twine &Name);

@@ -713,8 +740,9 @@ class LegalizeBufferContentTypesVisitor
/// Return the [index, length] pairs into which `T` needs to be cut to form
/// legal buffer load or store operations. Clears `Slices`. Creates an empty
/// `Slices` for non-vector inputs and creates one slice if no slicing will be
/// needed.
void getVecSlices(Type *T, SmallVectorImpl<VecSlice> &Slices);
/// needed. No slice may be larger than `MaxWidth`.
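/// For example, <8 x i16> with a `MaxWidth` of 128 produces the single
/// slice [0, 8], while a `MaxWidth` of 32 produces [0, 2], [2, 2], [4, 2],
/// [6, 2].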
void getVecSlices(Type *T, uint64_t MaxWidth,
SmallVectorImpl<VecSlice> &Slices);

Value *extractSlice(Value *Vec, VecSlice S, const Twine &Name);
Value *insertSlice(Value *Whole, Value *Part, VecSlice S, const Twine &Name);
@@ -743,8 +771,9 @@ class LegalizeBufferContentTypesVisitor
bool visitStoreInst(StoreInst &SI);

public:
LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx)
: IRB(Ctx, InstSimplifyFolder(DL)), DL(DL) {}
LegalizeBufferContentTypesVisitor(const DataLayout &DL, LLVMContext &Ctx,
const TargetMachine *TM)
: IRB(Ctx, InstSimplifyFolder(DL)), DL(DL), TM(TM) {}
bool processFunction(Function &F);
};
} // namespace
@@ -791,7 +820,48 @@ Value *LegalizeBufferContentTypesVisitor::vectorToArray(Value *V,
return ArrayRes;
}

Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
uint64_t LegalizeBufferContentTypesVisitor::maxIntrinsicWidth(Type *T,
Align A) {
Align Result(16);
if (!ST->hasUnalignedBufferAccessEnabled() && A < Align(4))
Result = A;
auto *VT = dyn_cast<VectorType>(T);
if (!ST->hasRelaxedBufferOOBMode() && VT) {
TypeSize ElemBits = DL.getTypeSizeInBits(VT->getElementType());
if (ElemBits.isKnownMultipleOf(32)) {
// Word-sized operations are bounds-checked per word. So, the only case we
// have to worry about is stores that start out of bounds and then go in,
// and those can only become in-bounds on a multiple of their alignment.
// Therefore, we can use the declared alignment of the operation as the
// maximum width, rounding up to 4.
Result = std::min(Result, std::max(A, Align(4)));
} else if (ElemBits.isKnownMultipleOf(8) ||
isPowerOf2_64(ElemBits.getKnownMinValue())) {
// To ensure correct behavior, we must always scalarize unaligned
// loads of sub-word types. For example, if you load
// a <4 x i8> from offset 7 in an 8-byte buffer, expecting the vector
// to be padded out with 0s after that last byte, you'll get all 0s
// instead. To prevent this behavior when not requested, de-vectorize such
// loads.
//
// This condition could be looser and mirror the word-length condition
// if we were allowed to assume that the number of records in a buffer
// was a multiple of 4: then, we could always use the declared
// alignment of the access on the assumption that no one wants their
// mask to kick in mid-word.
//
// Strict OOB checking isn't supported if the size of each element is a
// non-power-of-2 value less than 8, since there's no feasible way to
// apply such a strict bounds check.
Result =
commonAlignment(Result, divideCeil(ElemBits.getKnownMinValue(), 8));
}
}
return Result.value() * 8;
}
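
// Usage sketch for maxIntrinsicWidth(): visitLoadImpl and visitStoreImpl
// below both clamp the legalized type and the slicing to the computed
// width (`OrigAccess` here stands in for the original load or store):
//
//   Align PartAlign = commonAlignment(OrigAccess.getAlign(), AggByteOff);
//   uint64_t MaxWidth = maxIntrinsicWidth(ArrayAsVecType, PartAlign);
//   Type *LegalType = legalNonAggregateForMemOp(ArrayAsVecType, MaxWidth);
//   getVecSlices(LegalType, MaxWidth, Slices);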

Type *LegalizeBufferContentTypesVisitor::legalNonAggregateForMemOp(
Type *T, uint64_t MaxWidth) {
TypeSize Size = DL.getTypeStoreSizeInBits(T);
// Implicitly zero-extend to the next byte if needed
if (!DL.typeSizeEqualsStoreSize(T))
@@ -803,15 +873,16 @@ Type *LegalizeBufferContentTypesVisitor::legalNonAggregateFor(Type *T) {
return T;
}
unsigned ElemSize = DL.getTypeSizeInBits(ElemTy).getFixedValue();
if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= 128) {
if (isPowerOf2_32(ElemSize) && ElemSize >= 16 && ElemSize <= MaxWidth) {
// [vectors of] anything that's 16/32/64/128 bits can be cast and split into
// legal buffer operations.
// legal buffer operations, except that we might need to cut them into
// smaller values if we're not allowed to do unaligned vector loads.
return T;
}
Type *BestVectorElemType = nullptr;
if (Size.isKnownMultipleOf(32))
if (Size.isKnownMultipleOf(32) && MaxWidth >= 32)
BestVectorElemType = IRB.getInt32Ty();
else if (Size.isKnownMultipleOf(16))
else if (Size.isKnownMultipleOf(16) && MaxWidth >= 16)
BestVectorElemType = IRB.getInt16Ty();
else
BestVectorElemType = IRB.getInt8Ty();
@@ -884,7 +955,7 @@ Type *LegalizeBufferContentTypesVisitor::intrinsicTypeFor(Type *LegalType) {
}

void LegalizeBufferContentTypesVisitor::getVecSlices(
Type *T, SmallVectorImpl<VecSlice> &Slices) {
Type *T, uint64_t MaxWidth, SmallVectorImpl<VecSlice> &Slices) {
Slices.clear();
auto *VT = dyn_cast<FixedVectorType>(T);
if (!VT)
@@ -905,8 +976,8 @@ void LegalizeBufferContentTypesVisitor::getVecSlices(

uint64_t TotalElems = VT->getNumElements();
uint64_t Index = 0;
auto TrySlice = [&](unsigned MaybeLen) {
if (MaybeLen > 0 && Index + MaybeLen <= TotalElems) {
auto TrySlice = [&](unsigned MaybeLen, unsigned Width) {
if (MaybeLen > 0 && Width <= MaxWidth && Index + MaybeLen <= TotalElems) {
VecSlice Slice{/*Index=*/Index, /*Length=*/MaybeLen};
Slices.push_back(Slice);
Index += MaybeLen;
@@ -915,9 +986,9 @@ void LegalizeBufferContentTypesVisitor::getVecSlices(
return false;
};
while (Index < TotalElems) {
TrySlice(ElemsPer4Words) || TrySlice(ElemsPer3Words) ||
TrySlice(ElemsPer2Words) || TrySlice(ElemsPerWord) ||
TrySlice(ElemsPerShort) || TrySlice(ElemsPerByte);
TrySlice(ElemsPer4Words, 128) || TrySlice(ElemsPer3Words, 96) ||
TrySlice(ElemsPer2Words, 64) || TrySlice(ElemsPerWord, 32) ||
TrySlice(ElemsPerShort, 16) || TrySlice(ElemsPerByte, 8);
}
}
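
// Example trace: for <8 x i8> clamped to a MaxWidth of 8 (the scalarized
// sub-word case), every multi-element candidate fails the `Width <= MaxWidth`
// check, so TrySlice(ElemsPerByte, 8) fires eight times and yields the
// slices {0,1} through {7,1}.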

@@ -1004,11 +1075,13 @@ bool LegalizeBufferContentTypesVisitor::visitLoadImpl(

// Typical case

Align PartAlign = commonAlignment(OrigLI.getAlign(), AggByteOff);
Type *ArrayAsVecType = scalarArrayTypeAsVector(PartType);
Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
uint64_t MaxWidth = maxIntrinsicWidth(ArrayAsVecType, PartAlign);
Type *LegalType = legalNonAggregateForMemOp(ArrayAsVecType, MaxWidth);

SmallVector<VecSlice> Slices;
getVecSlices(LegalType, Slices);
getVecSlices(LegalType, MaxWidth, Slices);
bool HasSlices = Slices.size() > 1;
bool IsAggPart = !AggIdxs.empty();
Value *LoadsRes;
@@ -1045,7 +1118,8 @@ bool LegalizeBufferContentTypesVisitor::visitLoadImpl(
Value *NewPtr = IRB.CreateGEP(
IRB.getInt8Ty(), OrigLI.getPointerOperand(), IRB.getInt32(ByteOffset),
OrigPtr->getName() + ".off.ptr." + Twine(ByteOffset),
GEPNoWrapFlags::noUnsignedWrap());
ST->hasRelaxedBufferOOBMode() ? GEPNoWrapFlags::noUnsignedWrap()
: GEPNoWrapFlags::none());
Type *LoadableType = intrinsicTypeFor(SliceType);
LoadInst *NewLI = IRB.CreateAlignedLoad(
LoadableType, NewPtr, commonAlignment(OrigLI.getAlign(), ByteOffset),
@@ -1134,13 +1208,15 @@ std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
NewData = arrayToVector(NewData, ArrayAsVecType, Name);
}

Type *LegalType = legalNonAggregateFor(ArrayAsVecType);
Align PartAlign = commonAlignment(OrigSI.getAlign(), AggByteOff);
uint64_t MaxWidth = maxIntrinsicWidth(ArrayAsVecType, PartAlign);
Type *LegalType = legalNonAggregateForMemOp(ArrayAsVecType, MaxWidth);
if (LegalType != ArrayAsVecType) {
NewData = makeLegalNonAggregate(NewData, LegalType, Name);
}

SmallVector<VecSlice> Slices;
getVecSlices(LegalType, Slices);
getVecSlices(LegalType, MaxWidth, Slices);
bool NeedToSplit = Slices.size() > 1 || IsAggPart;
if (!NeedToSplit) {
Type *StorableType = intrinsicTypeFor(LegalType);
@@ -1161,10 +1237,11 @@ std::pair<bool, bool> LegalizeBufferContentTypesVisitor::visitStoreImpl(
Type *SliceType =
S.Length != 1 ? FixedVectorType::get(ElemType, S.Length) : ElemType;
int64_t ByteOffset = AggByteOff + S.Index * ElemBytes;
Value *NewPtr =
IRB.CreateGEP(IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset),
OrigPtr->getName() + ".part." + Twine(S.Index),
GEPNoWrapFlags::noUnsignedWrap());
Value *NewPtr = IRB.CreateGEP(
IRB.getInt8Ty(), OrigPtr, IRB.getInt32(ByteOffset),
OrigPtr->getName() + ".part." + Twine(S.Index),
ST->hasRelaxedBufferOOBMode() ? GEPNoWrapFlags::noUnsignedWrap()
: GEPNoWrapFlags::none());
Value *DataSlice = extractSlice(NewData, S, Name);
Type *StorableType = intrinsicTypeFor(SliceType);
DataSlice = IRB.CreateBitCast(DataSlice, StorableType,
@@ -1193,6 +1270,7 @@ bool LegalizeBufferContentTypesVisitor::visitStoreInst(StoreInst &SI) {
}

bool LegalizeBufferContentTypesVisitor::processFunction(Function &F) {
ST = &TM->getSubtarget<GCNSubtarget>(F);
bool Changed = false;
// Note, memory transfer intrinsics won't
for (Instruction &I : make_early_inc_range(instructions(F))) {
@@ -2438,8 +2516,8 @@ bool AMDGPULowerBufferFatPointers::run(Module &M, const TargetMachine &TM) {

StoreFatPtrsAsIntsAndExpandMemcpyVisitor MemOpsRewrite(&IntTM, DL,
M.getContext(), &TM);
LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(DL,
M.getContext());
LegalizeBufferContentTypesVisitor BufferContentsTypeRewrite(
DL, M.getContext(), &TM);
for (Function &F : M.functions()) {
bool InterfaceChange = hasFatPointerInterface(F, &StructTM);
bool BodyChanges = containsBufferFatPointers(F, &StructTM);