Skip to content

Commit 030d8cb

Browse files
committed
wip: Update DXILCBufferAccess to padding approach
1 parent 5797e18 commit 030d8cb

File tree

14 files changed

+179
-423
lines changed

14 files changed

+179
-423
lines changed

clang/lib/CodeGen/TargetInfo.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -448,6 +448,20 @@ class TargetCodeGenInfo {
448448
return nullptr;
449449
}
450450

451+
/// Default HLSL padding type: an array of i8 whose element count is
/// NumBytes.getQuantity(), created in CGM's LLVM context. Declared virtual
/// so that a target can substitute its own representation.
virtual llvm::Type *
452+
getHLSLPadding(CodeGenModule &CGM, CharUnits NumBytes) const {
453+
return llvm::ArrayType::get(llvm::Type::getInt8Ty(CGM.getLLVMContext()),
454+
NumBytes.getQuantity());
455+
}
456+
457+
/// Default padding test: returns true when Ty is an array type whose element
/// type is i8, and false for every other type.
virtual bool isHLSLPadding(llvm::Type *Ty) const {
458+
// TODO: Do we actually want to default these functions like this?
// NOTE(review): the check below looks only at the element type, so every
// i8 array (of any length) is reported as padding, not just arrays that
// were created as padding.
459+
if (auto *AT = dyn_cast<llvm::ArrayType>(Ty))
460+
if (AT->getElementType() == llvm::Type::getInt8Ty(Ty->getContext()))
461+
return true;
462+
return false;
463+
}
464+
451465
// Set the Branch Protection Attributes of the Function accordingly to the
452466
// BPI. Remove attributes that contradict with current BPI.
453467
static void

clang/lib/CodeGen/Targets/DirectX.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,19 @@ class DirectXTargetCodeGenInfo : public TargetCodeGenInfo {
3232
llvm::Type *
3333
getHLSLType(CodeGenModule &CGM, const Type *T,
3434
const SmallVector<int32_t> *Packoffsets = nullptr) const override;
35+
36+
/// DirectX override: represents padding as the target extension type
/// "dx.Padding" with no type parameters and a single integer parameter
/// holding the byte count.
llvm::Type *getHLSLPadding(CodeGenModule &CGM,
37+
CharUnits NumBytes) const override {
38+
// NOTE(review): the quantity is converted to unsigned here; presumably
// getQuantity() returns a wider signed integer, so large or negative
// values would not survive the conversion. TODO confirm.
unsigned Size = NumBytes.getQuantity();
39+
return llvm::TargetExtType::get(CGM.getLLVMContext(), "dx.Padding", {},
40+
{Size});
41+
}
42+
43+
/// DirectX override: Ty is padding only when it is a target extension type
/// whose name is exactly "dx.Padding"; any other type yields false.
bool isHLSLPadding(llvm::Type *Ty) const override {
44+
if (auto *TET = dyn_cast<llvm::TargetExtType>(Ty))
45+
return TET->getName() == "dx.Padding";
46+
return false;
47+
}
3548
};
3649

3750
llvm::Type *DirectXTargetCodeGenInfo::getHLSLType(

clang/lib/CodeGen/Targets/SPIR.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,20 @@ class CommonSPIRTargetCodeGenInfo : public TargetCodeGenInfo {
5656
llvm::Type *
5757
getHLSLType(CodeGenModule &CGM, const Type *Ty,
5858
const SmallVector<int32_t> *Packoffsets = nullptr) const override;
59+
60+
/// SPIR-V override: represents padding as the target extension type
/// "spirv.Padding" with no type parameters and a single integer parameter
/// holding the byte count.
llvm::Type *
61+
getHLSLPadding(CodeGenModule &CGM, CharUnits NumBytes) const override {
62+
// NOTE(review): the quantity is converted to unsigned here; presumably
// getQuantity() returns a wider signed integer, so large or negative
// values would not survive the conversion. TODO confirm.
unsigned Size = NumBytes.getQuantity();
63+
return llvm::TargetExtType::get(CGM.getLLVMContext(), "spirv.Padding", {},
64+
{Size});
65+
}
66+
67+
/// SPIR-V override: Ty is padding only when it is a target extension type
/// whose name is exactly "spirv.Padding"; any other type yields false.
bool isHLSLPadding(llvm::Type *Ty) const override {
68+
if (auto *TET = dyn_cast<llvm::TargetExtType>(Ty))
69+
return TET->getName() == "spirv.Padding";
70+
return false;
71+
}
72+
5973
llvm::Type *getSPIRVImageTypeFromHLSLResource(
6074
const HLSLAttributedResourceType::Attributes &attributes,
6175
QualType SampledType, CodeGenModule &CGM) const;

llvm/include/llvm/Frontend/HLSL/CBuffer.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,8 @@ class CBufferMetadata {
4646
CBufferMetadata(NamedMDNode *MD) : MD(MD) {}
4747

4848
public:
49-
static std::optional<CBufferMetadata> get(Module &M);
49+
static std::optional<CBufferMetadata>
50+
get(Module &M, llvm::function_ref<bool(Type *)> IsPadding);
5051

5152
using iterator = SmallVector<CBufferMapping>::iterator;
5253
iterator begin() { return Mappings.begin(); }
@@ -55,9 +56,6 @@ class CBufferMetadata {
5556
void eraseFromModule();
5657
};
5758

58-
APInt translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
59-
ArrayType *Ty);
60-
6159
} // namespace hlsl
6260
} // namespace llvm
6361

llvm/lib/Frontend/HLSL/CBuffer.cpp

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -15,25 +15,28 @@
1515
using namespace llvm;
1616
using namespace llvm::hlsl;
1717

18-
static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
18+
/// Collects the byte offsets of the non-padding elements of a cbuffer's
/// layout struct, in element order.
///
/// Handle's value type must be a target extension type whose name ends in
/// ".CBuffer" or equals "spirv.VulkanBuffer", carrying exactly one type
/// parameter; that parameter is cast to StructType and used as the layout.
/// Elements for which IsPadding returns true are skipped, so entry K of the
/// result is the offset of the K-th non-padding element.
///
/// This span is a diff rendering: lines following a "NN-" marker belong to
/// the removed getMemberOffset implementation and are not part of the new
/// function body.
static SmallVector<size_t>
19+
getMemberOffsets(const DataLayout &DL, GlobalVariable *Handle,
20+
llvm::function_ref<bool(Type *)> IsPadding) {
21+
SmallVector<size_t> Offsets;
22+
1923
auto *HandleTy = cast<TargetExtType>(Handle->getValueType());
2024
assert((HandleTy->getName().ends_with(".CBuffer") ||
2125
HandleTy->getName() == "spirv.VulkanBuffer") &&
2226
"Not a cbuffer type");
2327
assert(HandleTy->getNumTypeParameters() == 1 && "Expected layout type");
28+
auto *LayoutTy = cast<StructType>(HandleTy->getTypeParameter(0));
2429

25-
auto *LayoutTy = cast<TargetExtType>(HandleTy->getTypeParameter(0));
26-
assert(LayoutTy->getName().ends_with(".Layout") && "Not a layout type");
27-
28-
// Skip the "size" parameter.
29-
size_t ParamIndex = Index + 1;
30-
assert(LayoutTy->getNumIntParameters() > ParamIndex &&
31-
"Not enough parameters");
30+
// Offsets come from the DataLayout's struct layout for LayoutTy rather than
// from integer parameters stored on the type.
const StructLayout *SL = DL.getStructLayout(LayoutTy);
31+
for (int I = 0, E = LayoutTy->getNumElements(); I < E; ++I)
32+
if (!IsPadding(LayoutTy->getElementType(I)))
33+
Offsets.push_back(SL->getElementOffset(I));
3234

33-
return LayoutTy->getIntParameter(ParamIndex);
35+
return Offsets;
3436
}
3537

36-
std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
38+
std::optional<CBufferMetadata>
39+
CBufferMetadata::get(Module &M, llvm::function_ref<bool(Type *)> IsPadding) {
3740
NamedMDNode *CBufMD = M.getNamedMetadata("hlsl.cbs");
3841
if (!CBufMD)
3942
return std::nullopt;
@@ -47,13 +50,16 @@ std::optional<CBufferMetadata> CBufferMetadata::get(Module &M) {
4750
cast<ValueAsMetadata>(MD->getOperand(0))->getValue());
4851
CBufferMapping &Mapping = Result->Mappings.emplace_back(Handle);
4952

53+
SmallVector<size_t> MemberOffsets =
54+
getMemberOffsets(M.getDataLayout(), Handle, IsPadding);
55+
5056
for (int I = 1, E = MD->getNumOperands(); I < E; ++I) {
5157
Metadata *OpMD = MD->getOperand(I);
5258
// Some members may be null if they've been optimized out.
5359
if (!OpMD)
5460
continue;
5561
auto *V = cast<GlobalVariable>(cast<ValueAsMetadata>(OpMD)->getValue());
56-
Mapping.Members.emplace_back(V, getMemberOffset(Handle, I - 1));
62+
Mapping.Members.emplace_back(V, MemberOffsets[I - 1]);
5763
}
5864
}
5965

@@ -64,10 +70,3 @@ void CBufferMetadata::eraseFromModule() {
6470
// Remove the cbs named metadata
6571
MD->eraseFromParent();
6672
}
67-
68-
APInt hlsl::translateCBufArrayOffset(const DataLayout &DL, APInt Offset,
69-
ArrayType *Ty) {
70-
int64_t TypeSize = DL.getTypeSizeInBits(Ty->getElementType()) / 8;
71-
int64_t RoundUp = alignTo(TypeSize, Align(CBufferRowSizeInBytes));
72-
return Offset.udiv(TypeSize) * RoundUp;
73-
}

llvm/lib/IR/Type.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,10 @@ static TargetTypeInfo getTargetTypeInfo(const TargetExtType *Ty) {
10081008
}
10091009
if (Name == "spirv.IntegralConstant" || Name == "spirv.Literal")
10101010
return TargetTypeInfo(Type::getVoidTy(C));
1011+
if (Name == "spirv.Padding")
1012+
return TargetTypeInfo(
1013+
ArrayType::get(Type::getInt8Ty(C), Ty->getIntParameter(0)),
1014+
TargetExtType::CanBeGlobal);
10111015
if (Name.starts_with("spirv."))
10121016
return TargetTypeInfo(PointerType::get(C, 0), TargetExtType::HasZeroInit,
10131017
TargetExtType::CanBeGlobal,

llvm/lib/Target/DirectX/DXILCBufferAccess.cpp

Lines changed: 5 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include "DXILCBufferAccess.h"
1010
#include "DirectX.h"
11+
#include "llvm/Analysis/DXILResource.h"
1112
#include "llvm/Frontend/HLSL/CBuffer.h"
1213
#include "llvm/Frontend/HLSL/HLSLResource.h"
1314
#include "llvm/IR/IRBuilder.h"
@@ -97,10 +98,6 @@ struct CBufferResource {
9798
(void)Success;
9899
assert(Success && "Offsets into cbuffer globals must be constant");
99100

100-
if (auto *ATy = dyn_cast<ArrayType>(Member->getValueType()))
101-
ConstantOffset =
102-
hlsl::translateCBufArrayOffset(DL, ConstantOffset, ATy);
103-
104101
return ConstantOffset.getZExtValue();
105102
}
106103

@@ -194,102 +191,21 @@ static void replaceLoad(LoadInst *LI, CBufferResource &CBR,
194191
DeadInsts.push_back(LI);
195192
}
196193

197-
/// This function recursively copies N array elements from the cbuffer resource
198-
/// CBR to the MemCpy Destination. Recursion is used to unravel multidimensional
199-
/// arrays into a sequence of scalar/vector extracts and stores.
200-
static void copyArrayElemsForMemCpy(IRBuilder<> &Builder, MemCpyInst *MCI,
201-
CBufferResource &CBR, ArrayType *ArrTy,
202-
size_t ArrOffset, size_t N,
203-
const Twine &Name = "") {
204-
const DataLayout &DL = MCI->getDataLayout();
205-
Type *ElemTy = ArrTy->getElementType();
206-
size_t ElemTySize = DL.getTypeAllocSize(ElemTy);
207-
for (unsigned I = 0; I < N; ++I) {
208-
size_t Offset = ArrOffset + I * ElemTySize;
209-
210-
// Recursively copy nested arrays
211-
if (ArrayType *ElemArrTy = dyn_cast<ArrayType>(ElemTy)) {
212-
copyArrayElemsForMemCpy(Builder, MCI, CBR, ElemArrTy, Offset,
213-
ElemArrTy->getNumElements(), Name);
214-
continue;
215-
}
216-
217-
// Load CBuffer value and store it in Dest
218-
APInt CBufArrayOffset(
219-
DL.getIndexTypeSizeInBits(MCI->getSource()->getType()), Offset);
220-
CBufArrayOffset =
221-
hlsl::translateCBufArrayOffset(DL, CBufArrayOffset, ArrTy);
222-
Value *CBufferVal =
223-
CBR.loadValue(Builder, ElemTy, CBufArrayOffset.getZExtValue(), Name);
224-
Value *GEP =
225-
Builder.CreateInBoundsGEP(Builder.getInt8Ty(), MCI->getDest(),
226-
{Builder.getInt32(Offset)}, Name + ".dest");
227-
Builder.CreateStore(CBufferVal, GEP, MCI->isVolatile());
228-
}
229-
}
230-
231-
/// Replace memcpy from a cbuffer global with a memcpy from the cbuffer handle
232-
/// itself. Assumes the cbuffer global is an array, and the length of bytes to
233-
/// copy is divisible by array element allocation size.
234-
/// The memcpy source must also be a direct cbuffer global reference, not a GEP.
235-
static void replaceMemCpy(MemCpyInst *MCI, CBufferResource &CBR) {
236-
237-
ArrayType *ArrTy = dyn_cast<ArrayType>(CBR.getValueType());
238-
assert(ArrTy && "MemCpy lowering is only supported for array types");
239-
240-
// This assumption vastly simplifies the implementation
241-
if (MCI->getSource() != CBR.Member)
242-
reportFatalUsageError(
243-
"Expected MemCpy source to be a cbuffer global variable");
244-
245-
ConstantInt *Length = dyn_cast<ConstantInt>(MCI->getLength());
246-
uint64_t ByteLength = Length->getZExtValue();
247-
248-
// If length to copy is zero, no memcpy is needed
249-
if (ByteLength == 0) {
250-
MCI->eraseFromParent();
251-
return;
252-
}
253-
254-
const DataLayout &DL = CBR.getDataLayout();
255-
256-
Type *ElemTy = ArrTy->getElementType();
257-
size_t ElemSize = DL.getTypeAllocSize(ElemTy);
258-
assert(ByteLength % ElemSize == 0 &&
259-
"Length of bytes to MemCpy must be divisible by allocation size of "
260-
"source/destination array elements");
261-
size_t ElemsToCpy = ByteLength / ElemSize;
262-
263-
IRBuilder<> Builder(MCI);
264-
CBR.createAndSetCurrentHandle(Builder);
265-
266-
copyArrayElemsForMemCpy(Builder, MCI, CBR, ArrTy, 0, ElemsToCpy,
267-
"memcpy." + MCI->getDest()->getName() + "." +
268-
MCI->getSource()->getName());
269-
270-
MCI->eraseFromParent();
271-
}
272-
273194
static void replaceAccessesWithHandle(CBufferResource &CBR) {
274195
SmallVector<WeakTrackingVH> DeadInsts;
275196

276197
SmallVector<User *> ToProcess{CBR.users()};
277198
while (!ToProcess.empty()) {
278199
User *Cur = ToProcess.pop_back_val();
200+
assert(!isa<MemCpyInst>(Cur) &&
201+
"memcpy should have been removed in an earlier pass");
279202

280203
// If we have a load instruction, replace the access.
281204
if (auto *LI = dyn_cast<LoadInst>(Cur)) {
282205
replaceLoad(LI, CBR, DeadInsts);
283206
continue;
284207
}
285208

286-
// If we have a memcpy instruction, replace it with multiple accesses and
287-
// subsequent stores to the destination
288-
if (auto *MCI = dyn_cast<MemCpyInst>(Cur)) {
289-
replaceMemCpy(MCI, CBR);
290-
continue;
291-
}
292-
293209
// Otherwise, walk users looking for a load...
294210
if (isa<GetElementPtrInst>(Cur) || isa<GEPOperator>(Cur)) {
295211
ToProcess.append(Cur->user_begin(), Cur->user_end());
@@ -302,7 +218,8 @@ static void replaceAccessesWithHandle(CBufferResource &CBR) {
302218
}
303219

304220
static bool replaceCBufferAccesses(Module &M) {
305-
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
221+
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(
222+
M, [](Type *Ty) { return isa<llvm::dxil::PaddingExtType>(Ty); });
306223
if (!CBufMD)
307224
return false;
308225

llvm/test/CodeGen/DirectX/CBufferAccess/array-typedgep.ll

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,24 +3,24 @@
33
; cbuffer CB : register(b0) {
44
; float a1[3];
55
; }
6-
%__cblayout_CB = type <{ [3 x float] }>
6+
%__cblayout_CB = type <{ [2 x <{ float, [12 x i8] }>], float }>
77

8-
@CB.cb = global target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 36, 0)) poison
8+
@CB.cb = global target("dx.CBuffer", %__cblayout_CB) poison
99
; CHECK: @CB.cb =
1010
; CHECK-NOT: external {{.*}} addrspace(2) global
11-
@a1 = external addrspace(2) global [3 x float], align 4
11+
@a1 = external addrspace(2) global <{ [2 x <{ float, [12 x i8] }>], float }>, align 4
1212

1313
; CHECK: define void @f
1414
define void @f(ptr %dst) {
1515
entry:
16-
%CB.cb_h = call target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 36, 0)) @llvm.dx.resource.handlefrombinding.tdx.CBuffer_tdx.Layout_s___cblayout_CBs_36_0tt(i32 0, i32 0, i32 1, i32 0, ptr null)
17-
store target("dx.CBuffer", target("dx.Layout", %__cblayout_CB, 36, 0)) %CB.cb_h, ptr @CB.cb, align 4
16+
%CB.cb_h = call target("dx.CBuffer", %__cblayout_CB) @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, ptr null)
17+
store target("dx.CBuffer", %__cblayout_CB) %CB.cb_h, ptr @CB.cb, align 4
1818

19-
; CHECK: [[CB:%.*]] = load target("dx.CBuffer", {{.*}})), ptr @CB.cb
20-
; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", {{.*}})) [[CB]], i32 1)
19+
; CHECK: [[CB:%.*]] = load target("dx.CBuffer", %__cblayout_CB), ptr @CB.cb
20+
; CHECK: [[LOAD:%.*]] = call { float, float, float, float } @llvm.dx.resource.load.cbufferrow.4.{{.*}}(target("dx.CBuffer", %__cblayout_CB) [[CB]], i32 1)
2121
; CHECK: [[X:%.*]] = extractvalue { float, float, float, float } [[LOAD]], 0
2222
; CHECK: store float [[X]], ptr %dst
23-
%a1 = load float, ptr addrspace(2) getelementptr inbounds ([3 x float], ptr addrspace(2) @a1, i32 0, i32 1), align 4
23+
%a1 = load float, ptr addrspace(2) getelementptr inbounds (<{ [2 x <{ float, [12 x i8] }>], float }>, ptr addrspace(2) @a1, i32 0, i32 0, i32 1), align 4
2424
store float %a1, ptr %dst, align 32
2525

2626
ret void

0 commit comments

Comments
 (0)