Skip to content

Commit 8cebb59

Browse files
committed
[HLSL] Codegen for simple cbuffer blocks without embedded arrays or structs
1 parent 4e2a9e5 commit 8cebb59

15 files changed

+392
-202
lines changed

clang/include/clang/Basic/Attr.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4711,9 +4711,9 @@ def HLSLPackOffset: HLSLAnnotationAttr {
47114711
let Args = [IntArgument<"Subcomponent">, IntArgument<"Component">];
47124712
let Documentation = [HLSLPackOffsetDocs];
47134713
let AdditionalMembers = [{
4714-
unsigned getOffset() {
4715-
return subcomponent * 4 + component;
4716-
}
4714+
unsigned getOffsetInBytes() {
4715+
return subcomponent * 16 + component * 4;
4716+
}
47174717
}];
47184718
}
47194719

clang/lib/CodeGen/CGDeclCXX.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -886,6 +886,10 @@ CodeGenModule::EmitCXXGlobalInitFunc() {
886886
ModuleInits.push_back(Fn);
887887
}
888888

889+
if (getLangOpts().HLSL && getHLSLRuntime().needsResourceBindingInitFn()) {
890+
CXXGlobalInits.push_back(getHLSLRuntime().createResourceBindingInitFn());
891+
}
892+
889893
if (ModuleInits.empty() && CXXGlobalInits.empty() &&
890894
PrioritizedCXXGlobalInits.empty())
891895
return;
@@ -1127,14 +1131,6 @@ CodeGenFunction::GenerateCXXGlobalInitFunc(llvm::Function *Fn,
11271131
if (Decls[i])
11281132
EmitRuntimeCall(Decls[i]);
11291133

1130-
if (getLangOpts().HLSL) {
1131-
CGHLSLRuntime &CGHLSL = CGM.getHLSLRuntime();
1132-
if (CGHLSL.needsResourceBindingInitFn()) {
1133-
llvm::Function *ResInitFn = CGHLSL.createResourceBindingInitFn();
1134-
Builder.CreateCall(llvm::FunctionCallee(ResInitFn), {});
1135-
}
1136-
}
1137-
11381134
Scope.ForceCleanup();
11391135

11401136
if (ExitBlock) {

clang/lib/CodeGen/CGHLSLRuntime.cpp

Lines changed: 155 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -54,69 +54,110 @@ void addDxilValVersion(StringRef ValVersionStr, llvm::Module &M) {
5454
auto *DXILValMD = M.getOrInsertNamedMetadata(DXILValKey);
5555
DXILValMD->addOperand(Val);
5656
}
57+
5758
void addDisableOptimizations(llvm::Module &M) {
5859
StringRef Key = "dx.disable_optimizations";
5960
M.addModuleFlag(llvm::Module::ModFlagBehavior::Override, Key, 1);
6061
}
61-
// cbuffer will be translated into global variable in special address space.
62-
// If translate into C,
63-
// cbuffer A {
64-
// float a;
65-
// float b;
66-
// }
67-
// float foo() { return a + b; }
68-
//
69-
// will be translated into
70-
//
71-
// struct A {
72-
// float a;
73-
// float b;
74-
// } cbuffer_A __attribute__((address_space(4)));
75-
// float foo() { return cbuffer_A.a + cbuffer_A.b; }
76-
//
77-
// layoutBuffer will create the struct A type.
78-
// replaceBuffer will replace use of global variable a and b with cbuffer_A.a
79-
// and cbuffer_A.b.
80-
//
81-
void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
82-
if (Buf.Constants.empty())
83-
return;
62+
63+
// Creates the LLVM struct type representing the shape of the constant buffer
64+
// which will be included in the LLVM target type and calculates the memory
65+
// layout and constant buffer layout offsets of each constant.
66+
static void layoutBuffer(CGHLSLRuntime::Buffer &Buf, const DataLayout &DL) {
67+
assert(!Buf.Constants.empty() &&
68+
"empty constant buffer should not be created");
8469

8570
std::vector<llvm::Type *> EltTys;
86-
for (auto &Const : Buf.Constants) {
87-
GlobalVariable *GV = Const.first;
88-
Const.second = EltTys.size();
71+
unsigned MemOffset = 0, CBufOffset = 0, Size = 0;
72+
73+
for (auto &C : Buf.Constants) {
74+
GlobalVariable *GV = C.GlobalVar;
8975
llvm::Type *Ty = GV->getValueType();
76+
77+
assert(!Ty->isArrayTy() && !Ty->isStructTy() &&
78+
"arrays and structs in cbuffer are not yet implemened");
79+
80+
// scalar type, vector or matrix
9081
EltTys.emplace_back(Ty);
82+
unsigned FieldSize = Ty->getScalarSizeInBits() / 8;
83+
if (Ty->isVectorTy())
84+
FieldSize *= cast<FixedVectorType>(Ty)->getNumElements();
85+
assert(FieldSize <= 16 && "field side larger than constant buffer row");
86+
87+
// set memory layout offset (no padding)
88+
C.MemOffset = MemOffset;
89+
MemOffset += FieldSize;
90+
91+
// calculate cbuffer layout offset or update total cbuffer size from
92+
// packoffset annotations
93+
if (Buf.HasPackoffset) {
94+
assert(C.CBufferOffset != UINT_MAX &&
95+
"cbuffer offset should have been set from packoffset attribute");
96+
unsigned OffsetAfterField = C.CBufferOffset + FieldSize;
97+
if (Size < OffsetAfterField)
98+
Size = OffsetAfterField;
99+
} else {
100+
// allign to the size of the field
101+
CBufOffset = llvm::alignTo(CBufOffset, FieldSize);
102+
C.CBufferOffset = CBufOffset;
103+
CBufOffset += FieldSize;
104+
Size = CBufOffset;
105+
}
91106
}
92107
Buf.LayoutStruct = llvm::StructType::get(EltTys[0]->getContext(), EltTys);
108+
Buf.Size = Size;
93109
}
94110

95-
GlobalVariable *replaceBuffer(CGHLSLRuntime::Buffer &Buf) {
96-
// Create global variable for CB.
97-
GlobalVariable *CBGV = new GlobalVariable(
98-
Buf.LayoutStruct, /*isConstant*/ true,
99-
GlobalValue::LinkageTypes::ExternalLinkage, nullptr,
100-
llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb." : ".tb."),
101-
GlobalValue::NotThreadLocal);
111+
// Creates LLVM target type target("dx.CBuffer",..) for the constant buffer.
112+
// The target type includes the LLVM struct type representing the shape
113+
// of the constant buffer, size, and a list of offsets for each fields
114+
// in cbuffer layout.
115+
static llvm::Type *getBufferTargetType(LLVMContext &Ctx,
116+
CGHLSLRuntime::Buffer &Buf) {
117+
assert(Buf.LayoutStruct != nullptr && Buf.Size != UINT_MAX &&
118+
"the buffer layout has not been calculated yet");
119+
llvm::SmallVector<unsigned> SizeAndOffsets;
120+
SizeAndOffsets.reserve(Buf.Constants.size() + 1);
121+
SizeAndOffsets.push_back(Buf.Size);
122+
for (CGHLSLRuntime::BufferConstant &C : Buf.Constants) {
123+
SizeAndOffsets.push_back(C.CBufferOffset);
124+
}
125+
return llvm::TargetExtType::get(Ctx, "dx.CBuffer", {Buf.LayoutStruct},
126+
SizeAndOffsets);
127+
}
102128

103-
IRBuilder<> B(CBGV->getContext());
104-
Value *ZeroIdx = B.getInt32(0);
105-
// Replace Const use with CB use.
106-
for (auto &[GV, Offset] : Buf.Constants) {
107-
Value *GEP =
108-
B.CreateGEP(Buf.LayoutStruct, CBGV, {ZeroIdx, B.getInt32(Offset)});
109-
110-
assert(Buf.LayoutStruct->getElementType(Offset) == GV->getValueType() &&
111-
"constant type mismatch");
112-
113-
// Replace.
114-
GV->replaceAllUsesWith(GEP);
115-
// Erase GV.
116-
GV->removeDeadConstantUsers();
117-
GV->eraseFromParent();
129+
// Replaces all uses of the temporary constant buffer global variables with
130+
// buffer access intrinsic resource.getpointer.
131+
static void replaceBufferGlobals(CodeGenModule &CGM,
132+
CGHLSLRuntime::Buffer &Buf) {
133+
assert(Buf.IsCBuffer && "tbuffer codegen is not yet supported");
134+
135+
GlobalVariable *BufGV = Buf.GlobalVar;
136+
for (auto &Constant : Buf.Constants) {
137+
GlobalVariable *ConstGV = Constant.GlobalVar;
138+
139+
// TODO: Map to an hlsl_device address space.
140+
llvm::Type *RetTy = ConstGV->getType();
141+
llvm::Type *TargetTy = BufGV->getValueType();
142+
143+
// Replace all uses of GV with CBuffer access
144+
while (ConstGV->use_begin() != ConstGV->use_end()) {
145+
Use &U = *ConstGV->use_begin();
146+
if (Instruction *UserInstr = dyn_cast<Instruction>(U.getUser())) {
147+
IRBuilder<> Builder(UserInstr);
148+
Value *Handle = Builder.CreateLoad(TargetTy, BufGV);
149+
Value *ResGetPointer = Builder.CreateIntrinsic(
150+
RetTy, Intrinsic::dx_resource_getpointer,
151+
ArrayRef<llvm::Value *>{Handle,
152+
Builder.getInt32(Constant.MemOffset)});
153+
U.set(ResGetPointer);
154+
} else {
155+
llvm_unreachable("unexpected use of constant value");
156+
}
157+
}
158+
ConstGV->removeDeadConstantUsers();
159+
ConstGV->eraseFromParent();
118160
}
119-
return CBGV;
120161
}
121162

122163
} // namespace
@@ -143,20 +184,22 @@ void CGHLSLRuntime::addConstant(VarDecl *D, Buffer &CB) {
143184
return;
144185
}
145186

187+
assert(!D->getType()->isArrayType() && !D->getType()->isStructureType() &&
188+
"codegen for arrays and structs in cbuffer is not yet supported");
189+
146190
auto *GV = cast<GlobalVariable>(CGM.GetAddrOfGlobalVar(D));
147191
// Add debug info for constVal.
148192
if (CGDebugInfo *DI = CGM.getModuleDebugInfo())
149193
if (CGM.getCodeGenOpts().getDebugInfo() >=
150194
codegenoptions::DebugInfoKind::LimitedDebugInfo)
151195
DI->EmitGlobalVariable(cast<GlobalVariable>(GV), D);
152196

153-
// FIXME: support packoffset.
154-
// See https://github.com/llvm/llvm-project/issues/57914.
155-
uint32_t Offset = 0;
156-
bool HasUserOffset = false;
197+
CB.Constants.emplace_back(GV);
157198

158-
unsigned LowerBound = HasUserOffset ? Offset : UINT_MAX;
159-
CB.Constants.emplace_back(std::make_pair(GV, LowerBound));
199+
if (HLSLPackOffsetAttr *PO = D->getAttr<HLSLPackOffsetAttr>()) {
200+
CB.HasPackoffset = true;
201+
CB.Constants.back().CBufferOffset = PO->getOffsetInBytes();
202+
}
160203
}
161204

162205
void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
@@ -173,9 +216,37 @@ void CGHLSLRuntime::addBufferDecls(const DeclContext *DC, Buffer &CB) {
173216
}
174217
}
175218

219+
// Creates temporary global variables for all declarations within the constant
220+
// buffer context, calculates the buffer layouts, and then creates a global
221+
// variable for the constant buffer and adds it to the module.
222+
// All uses of the temporary constant globals will be replaced with buffer
223+
// access intrinsic resource.getpointer in CGHLSLRuntime::finishCodeGen.
224+
// Later on in DXILResourceAccess pass these will be transtaled
225+
// to dx.op.cbufferLoadLegacy instructions.
176226
void CGHLSLRuntime::addBuffer(const HLSLBufferDecl *D) {
177-
Buffers.emplace_back(Buffer(D));
178-
addBufferDecls(D, Buffers.back());
227+
llvm::Module &M = CGM.getModule();
228+
const DataLayout &DL = M.getDataLayout();
229+
230+
assert(D->isCBuffer() && "tbuffer codegen is not supported yet");
231+
232+
Buffer &Buf = Buffers.emplace_back(D);
233+
addBufferDecls(D, Buf);
234+
if (Buf.Constants.empty()) {
235+
// empty constant buffer - do not add to globals
236+
Buffers.pop_back();
237+
return;
238+
}
239+
layoutBuffer(Buf, DL);
240+
241+
// Create global variable for CB.
242+
llvm::Type *TargetTy = getBufferTargetType(CGM.getLLVMContext(), Buf);
243+
Buf.GlobalVar = new GlobalVariable(
244+
TargetTy, /*isConstant*/ true, GlobalValue::LinkageTypes::ExternalLinkage,
245+
nullptr, llvm::formatv("{0}{1}", Buf.Name, Buf.IsCBuffer ? ".cb" : ".tb"),
246+
GlobalValue::NotThreadLocal);
247+
248+
M.insertGlobalVariable(Buf.GlobalVar);
249+
ResourcesToBind.emplace_back(Buf.Decl, Buf.GlobalVar);
179250
}
180251

181252
void CGHLSLRuntime::finishCodeGen() {
@@ -189,26 +260,14 @@ void CGHLSLRuntime::finishCodeGen() {
189260
if (CGM.getCodeGenOpts().OptimizationLevel == 0)
190261
addDisableOptimizations(M);
191262

192-
const DataLayout &DL = M.getDataLayout();
193-
194263
for (auto &Buf : Buffers) {
195-
layoutBuffer(Buf, DL);
196-
GlobalVariable *GV = replaceBuffer(Buf);
197-
M.insertGlobalVariable(GV);
198-
llvm::hlsl::ResourceClass RC = Buf.IsCBuffer
199-
? llvm::hlsl::ResourceClass::CBuffer
200-
: llvm::hlsl::ResourceClass::SRV;
201-
llvm::hlsl::ResourceKind RK = Buf.IsCBuffer
202-
? llvm::hlsl::ResourceKind::CBuffer
203-
: llvm::hlsl::ResourceKind::TBuffer;
204-
addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false,
205-
llvm::hlsl::ElementType::Invalid, Buf.Binding);
264+
replaceBufferGlobals(CGM, Buf);
206265
}
207266
}
208267

209268
CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D)
210-
: Name(D->getName()), IsCBuffer(D->isCBuffer()),
211-
Binding(D->getAttr<HLSLResourceBindingAttr>()) {}
269+
: Name(D->getName()), IsCBuffer(D->isCBuffer()), HasPackoffset(false),
270+
LayoutStruct(nullptr), Decl(D), GlobalVar(nullptr) {}
212271

213272
void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
214273
llvm::hlsl::ResourceClass RC,
@@ -237,7 +296,7 @@ void CGHLSLRuntime::addBufferResourceAnnotation(llvm::GlobalVariable *GV,
237296
"ResourceMD must have been set by the switch above.");
238297

239298
llvm::hlsl::FrontendResource Res(
240-
GV, RK, ET, IsROV, Binding.Reg.value_or(UINT_MAX), Binding.Space);
299+
GV, RK, ET, IsROV, Binding.Slot.value_or(UINT_MAX), Binding.Space);
241300
ResourceMD->addOperand(Res.getMetadata());
242301
}
243302

@@ -328,12 +387,8 @@ void CGHLSLRuntime::annotateHLSLResource(const VarDecl *D, GlobalVariable *GV) {
328387
CGHLSLRuntime::BufferResBinding::BufferResBinding(
329388
HLSLResourceBindingAttr *Binding) {
330389
if (Binding) {
331-
llvm::APInt RegInt(64, 0);
332-
Binding->getSlot().substr(1).getAsInteger(10, RegInt);
333-
Reg = RegInt.getLimitedValue();
334-
llvm::APInt SpaceInt(64, 0);
335-
Binding->getSpace().substr(5).getAsInteger(10, SpaceInt);
336-
Space = SpaceInt.getLimitedValue();
390+
Slot = Binding->getSlotNumber();
391+
Space = Binding->getSpaceNumber();
337392
} else {
338393
Space = 0;
339394
}
@@ -572,24 +627,30 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
572627
const DataLayout &DL = CGM.getModule().getDataLayout();
573628
Builder.SetInsertPoint(EntryBB);
574629

575-
for (const auto &[VD, GV] : ResourcesToBind) {
576-
for (Attr *A : VD->getAttrs()) {
630+
for (const auto &[Decl, GV] : ResourcesToBind) {
631+
for (Attr *A : Decl->getAttrs()) {
577632
HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A);
578633
if (!RBA)
579634
continue;
580635

581-
const HLSLAttributedResourceType *AttrResType =
582-
HLSLAttributedResourceType::findHandleTypeOnResource(
583-
VD->getType().getTypePtr());
584-
585-
// FIXME: Only simple declarations of resources are supported for now.
586-
// Arrays of resources or resources in user defined classes are
587-
// not implemented yet.
588-
assert(AttrResType != nullptr &&
589-
"Resource class must have a handle of HLSLAttributedResourceType");
590-
591-
llvm::Type *TargetTy =
592-
CGM.getTargetCodeGenInfo().getHLSLType(CGM, AttrResType);
636+
llvm::Type *TargetTy = nullptr;
637+
if (const VarDecl *VD = dyn_cast<VarDecl>(Decl)) {
638+
const HLSLAttributedResourceType *AttrResType =
639+
HLSLAttributedResourceType::findHandleTypeOnResource(
640+
VD->getType().getTypePtr());
641+
642+
// FIXME: Only simple declarations of resources are supported for now.
643+
// Arrays of resources or resources in user defined classes are
644+
// not implemented yet.
645+
assert(
646+
AttrResType != nullptr &&
647+
"Resource class must have a handle of HLSLAttributedResourceType");
648+
649+
TargetTy = CGM.getTargetCodeGenInfo().getHLSLType(CGM, AttrResType);
650+
} else {
651+
assert(isa<HLSLBufferDecl>(Decl));
652+
TargetTy = GV->getValueType();
653+
}
593654
assert(TargetTy != nullptr &&
594655
"Failed to convert resource handle to target type");
595656

@@ -604,7 +665,7 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() {
604665

605666
llvm::Value *CreateHandle = Builder.CreateIntrinsic(
606667
/*ReturnType=*/TargetTy, getCreateHandleFromBindingIntrinsic(), Args,
607-
nullptr, Twine(VD->getName()).concat("_h"));
668+
nullptr, Twine(Decl->getName()).concat("_h"));
608669

609670
llvm::Value *HandleRef =
610671
Builder.CreateStructGEP(GV->getValueType(), GV, 0);

0 commit comments

Comments
 (0)