From a3b3ba874a827747aba7cd2370bfed5f18d5a29c Mon Sep 17 00:00:00 2001 From: Helena Kotas Date: Mon, 9 Dec 2024 18:08:32 -0800 Subject: [PATCH] [HLSL] Make explicit _init_resource_bindings call for each entry point function --- clang/lib/CodeGen/CGDeclCXX.cpp | 8 -- clang/lib/CodeGen/CGHLSLRuntime.cpp | 90 +++++++++++++------ clang/lib/CodeGen/CGHLSLRuntime.h | 14 +-- .../ByteAddressBuffers-constructors.hlsl | 16 ++-- .../builtins/RWBuffer-constructor.hlsl | 17 ++-- .../StructuredBuffers-constructors.hlsl | 18 ++-- clang/test/CodeGenHLSL/resource-bindings.hlsl | 5 +- 7 files changed, 111 insertions(+), 57 deletions(-) diff --git a/clang/lib/CodeGen/CGDeclCXX.cpp b/clang/lib/CodeGen/CGDeclCXX.cpp index 2c3054605ee75..cacb0555fe086 100644 --- a/clang/lib/CodeGen/CGDeclCXX.cpp +++ b/clang/lib/CodeGen/CGDeclCXX.cpp @@ -1127,14 +1127,6 @@ CodeGenFunction::GenerateCXXGlobalInitFunc(llvm::Function *Fn, if (Decls[i]) EmitRuntimeCall(Decls[i]); - if (getLangOpts().HLSL) { - CGHLSLRuntime &CGHLSL = CGM.getHLSLRuntime(); - if (CGHLSL.needsResourceBindingInitFn()) { - llvm::Function *ResInitFn = CGHLSL.createResourceBindingInitFn(); - Builder.CreateCall(llvm::FunctionCallee(ResInitFn), {}); - } - } - Scope.ForceCleanup(); if (ExitBlock) { diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 2c293523fca8c..6ea93453c81fc 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -204,6 +204,7 @@ void CGHLSLRuntime::finishCodeGen() { addBufferResourceAnnotation(GV, RC, RK, /*IsROV=*/false, llvm::hlsl::ElementType::Invalid, Buf.Binding); } + generateInitResBindingsFuncBody(); } CGHLSLRuntime::Buffer::Buffer(const HLSLBufferDecl *D) @@ -510,6 +511,10 @@ void CGHLSLRuntime::generateGlobalCtorDtorCalls() { for (auto *Fn : CtorFns) B.CreateCall(FunctionCallee(Fn), {}, OB); + // Insert a call to initialize resource handles from bindings + Function *ResInitFn = getOrCreateResourceBindingInitFn(); + 
B.CreateCall(ResInitFn); + // Insert global dtors before the terminator of the last instruction B.SetInsertPoint(F.back().getTerminator()); for (auto *Fn : DtorFns) @@ -545,46 +550,80 @@ void CGHLSLRuntime::handleGlobalVarDefinition(const VarDecl *VD, ResourcesToBind.emplace_back(VD, GV); } -bool CGHLSLRuntime::needsResourceBindingInitFn() { - return !ResourcesToBind.empty(); +// Creates a declaration of _init_resource_bindings function. It will later be +// populated with calls to initialize resource handles from bindings. +llvm::Function *CGHLSLRuntime::getOrCreateResourceBindingInitFn() { + if (InitResBindingsFunc == nullptr) { + InitResBindingsFunc = + llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false), + llvm::GlobalValue::InternalLinkage, + "_init_resource_bindings", CGM.getModule()); + } + return InitResBindingsFunc; +} + +void CGHLSLRuntime::removeInitResBindingsFunc() { + if (!InitResBindingsFunc) + return; + while (InitResBindingsFunc->user_begin() != InitResBindingsFunc->user_end()) { + User *U = *InitResBindingsFunc->user_begin(); + assert(isa<llvm::CallInst>(U)); + llvm::CallInst *CI = cast<llvm::CallInst>(U); + CI->eraseFromParent(); + } + InitResBindingsFunc->eraseFromParent(); + InitResBindingsFunc = nullptr; } -llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() { - // No resources to bind - assert(needsResourceBindingInitFn() && "no resources to bind"); +// Populates the body of _init_resource_bindings function with calls to +// initialize resource handles from bindings. If there are no resources to bind +// it will remove the function and all of its calls. +void CGHLSLRuntime::generateInitResBindingsFuncBody() { + if (ResourcesToBind.empty()) { + removeInitResBindingsFunc(); + return; + } + + if (InitResBindingsFunc == nullptr) { + // FIXME: resource init function did not get created in shader entry point. + // Is this a library or just a shader without an entry function? 
+ // llvm-project/llvm#119260 + return; + } LLVMContext &Ctx = CGM.getLLVMContext(); llvm::Type *Int1Ty = llvm::Type::getInt1Ty(Ctx); - llvm::Function *InitResBindingsFunc = - llvm::Function::Create(llvm::FunctionType::get(CGM.VoidTy, false), - llvm::GlobalValue::InternalLinkage, - "_init_resource_bindings", CGM.getModule()); - llvm::BasicBlock *EntryBB = llvm::BasicBlock::Create(Ctx, "entry", InitResBindingsFunc); CGBuilderTy Builder(CGM, Ctx); const DataLayout &DL = CGM.getModule().getDataLayout(); Builder.SetInsertPoint(EntryBB); - for (const auto &[VD, GV] : ResourcesToBind) { - for (Attr *A : VD->getAttrs()) { + for (const auto &[Decl, GV] : ResourcesToBind) { + for (Attr *A : Decl->getAttrs()) { HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); if (!RBA) continue; - const HLSLAttributedResourceType *AttrResType = - HLSLAttributedResourceType::findHandleTypeOnResource( - VD->getType().getTypePtr()); - - // FIXME: Only simple declarations of resources are supported for now. - // Arrays of resources or resources in user defined classes are - // not implemented yet. - assert(AttrResType != nullptr && - "Resource class must have a handle of HLSLAttributedResourceType"); - - llvm::Type *TargetTy = - CGM.getTargetCodeGenInfo().getHLSLType(CGM, AttrResType); + llvm::Type *TargetTy = nullptr; + if (const VarDecl *VD = dyn_cast<VarDecl>(Decl)) { + const HLSLAttributedResourceType *AttrResType = + HLSLAttributedResourceType::findHandleTypeOnResource( + VD->getType().getTypePtr()); + + // FIXME: Only simple declarations of resources are supported for now. + // Arrays of resources or resources in user defined classes are + // not implemented yet. 
+ assert( + AttrResType != nullptr && + "Resource class must have a handle of HLSLAttributedResourceType"); + + TargetTy = CGM.getTargetCodeGenInfo().getHLSLType(CGM, AttrResType); + } else { + assert(isa<HLSLBufferDecl>(Decl)); + llvm_unreachable("CBuffer codegen is not supported yet"); + } assert(TargetTy != nullptr && "Failed to convert resource handle to target type"); @@ -599,7 +638,7 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() { llvm::Value *CreateHandle = Builder.CreateIntrinsic( /*ReturnType=*/TargetTy, getCreateHandleFromBindingIntrinsic(), Args, - nullptr, Twine(VD->getName()).concat("_h")); + nullptr, Twine(Decl->getName()).concat("_h")); llvm::Value *HandleRef = Builder.CreateStructGEP(GV->getValueType(), GV, 0); @@ -609,7 +648,6 @@ llvm::Function *CGHLSLRuntime::createResourceBindingInitFn() { } Builder.CreateRetVoid(); - return InitResBindingsFunc; } llvm::Instruction *CGHLSLRuntime::getConvergenceToken(BasicBlock &BB) { diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index bb120c8b5e9e6..237896ccad0e2 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -54,6 +54,7 @@ class StructType; namespace clang { class VarDecl; +class NamedDecl; class ParmVarDecl; class HLSLBufferDecl; class HLSLResourceBindingAttr; @@ -136,7 +137,7 @@ class CGHLSLRuntime { llvm::Type *Ty); public: - CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM) {} + CGHLSLRuntime(CodeGenModule &CGM) : CGM(CGM), InitResBindingsFunc(nullptr) {} virtual ~CGHLSLRuntime() {} llvm::Type *convertHLSLSpecificType(const Type *T); @@ -153,8 +154,6 @@ class CGHLSLRuntime { void setHLSLFunctionAttributes(const FunctionDecl *FD, llvm::Function *Fn); void handleGlobalVarDefinition(const VarDecl *VD, llvm::GlobalVariable *Var); - bool needsResourceBindingInitFn(); - llvm::Function *createResourceBindingInitFn(); llvm::Instruction *getConvergenceToken(llvm::BasicBlock &BB); private: @@ -165,10 +164,15 @@ class CGHLSLRuntime { 
BufferResBinding &Binding); void addConstant(VarDecl *D, Buffer &CB); void addBufferDecls(const DeclContext *DC, Buffer &CB); + + llvm::Function *getOrCreateResourceBindingInitFn(); + void generateInitResBindingsFuncBody(); + void removeInitResBindingsFunc(); + llvm::Triple::ArchType getArch(); llvm::SmallVector<Buffer> Buffers; - - llvm::SmallVector<std::pair<const VarDecl *, llvm::GlobalVariable *>> + llvm::Function *InitResBindingsFunc; + llvm::SmallVector<std::pair<const NamedDecl *, llvm::GlobalVariable *>> ResourcesToBind; }; diff --git a/clang/test/CodeGenHLSL/builtins/ByteAddressBuffers-constructors.hlsl b/clang/test/CodeGenHLSL/builtins/ByteAddressBuffers-constructors.hlsl index 45e135427ba9c..efbf80077225a 100644 --- a/clang/test/CodeGenHLSL/builtins/ByteAddressBuffers-constructors.hlsl +++ b/clang/test/CodeGenHLSL/builtins/ByteAddressBuffers-constructors.hlsl @@ -1,5 +1,5 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV // NOTE: SPIRV codegen for resource types is not yet implemented @@ -15,9 +15,12 @@ RasterizerOrderedByteAddressBuffer Buffer2: register(u3, space4); // CHECK: @Buffer1 = global %"class.hlsl::RWByteAddressBuffer" zeroinitializer, align 4 // CHECK: @Buffer2 = global %"class.hlsl::RasterizerOrderedByteAddressBuffer" zeroinitializer, align 4 -// CHECK: define internal void @_GLOBAL__sub_I_ByteAddressBuffers_constructors.hlsl() -// CHECK: entry: -// CHECK: call void @_init_resource_bindings() +// CHECK: define void @main() +// CHECK-NEXT: entry: +// CHECK-NEXT: call void 
@_GLOBAL__sub_I_ByteAddressBuffers_constructors.hlsl() +// CHECK-NEXT: call void @_init_resource_bindings() +// CHECK-NEXT: call void @_Z4mainv() +// CHECK-NEXT: ret void // CHECK: define internal void @_init_resource_bindings() { // CHECK-NEXT: entry: @@ -27,3 +30,6 @@ RasterizerOrderedByteAddressBuffer Buffer2: register(u3, space4); // CHECK-DXIL-NEXT: store target("dx.RawBuffer", i8, 1, 0) %Buffer1_h, ptr @Buffer1, align 4 // CHECK-DXIL-NEXT: %Buffer2_h = call target("dx.RawBuffer", i8, 1, 1) @llvm.dx.handle.fromBinding.tdx.RawBuffer_i8_1_1t(i32 4, i32 3, i32 1, i32 0, i1 false) // CHECK-DXIL-NEXT: store target("dx.RawBuffer", i8, 1, 1) %Buffer2_h, ptr @Buffer2, align 4 + +[numthreads(4,1,1)] +void main() {} diff --git a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl index c2db56e2b2bdd..46ce5a227b78c 100644 --- a/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl +++ b/clang/test/CodeGenHLSL/builtins/RWBuffer-constructor.hlsl @@ -1,6 +1,6 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL // FIXME: SPIR-V codegen of llvm.spv.handle.fromBinding is not yet implemented -// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV // NOTE: SPIRV codegen for resource types is not yet implemented @@ -9,13 +9,15 @@ RWBuffer<float> Buf : register(u5, space3); // CHECK: %"class.hlsl::RWBuffer" = type { target("dx.TypedBuffer", float, 1, 0, 0) } // CHECK: @Buf = global %"class.hlsl::RWBuffer" zeroinitializer, 
align 4 -// CHECK: define linkonce_odr void @_ZN4hlsl8RWBufferIfEC2Ev(ptr noundef nonnull align 4 dereferenceable(4) %this) +// CHECK: define void @main() // CHECK-NEXT: entry: +// CHECK-NEXT: call void @_GLOBAL__sub_I_RWBuffer_constructor.hlsl() +// CHECK-NEXT: call void @_init_resource_bindings() +// CHECK-NEXT: call void @_Z4mainv() +// CHECK-NEXT: ret void -// CHECK: define internal void @_GLOBAL__sub_I_RWBuffer_constructor.hlsl() +// CHECK: define linkonce_odr void @_ZN4hlsl8RWBufferIfEC2Ev(ptr noundef nonnull align 4 dereferenceable(4) %this) // CHECK-NEXT: entry: -// CHECK-NEXT: call void @__cxx_global_var_init() -// CHECK-NEXT: call void @_init_resource_bindings() // CHECK: define internal void @_init_resource_bindings() { // CHECK-NEXT: entry: @@ -23,3 +25,6 @@ RWBuffer<float> Buf : register(u5, space3); // CHECK-DXIL-NEXT: store target("dx.TypedBuffer", float, 1, 0, 0) %Buf_h, ptr @Buf, align 4 // CHECK-SPIRV-NEXT: %Buf_h = call target("dx.TypedBuffer", float, 1, 0, 0) @llvm.spv.handle.fromBinding.tdx.TypedBuffer_f32_1_0_0t(i32 3, i32 5, i32 1, i32 0, i1 false) // CHECK-SPIRV-NEXT: store target("dx.TypedBuffer", float, 1, 0, 0) %Buf_h, ptr @Buf, align 4 + +[numthreads(4,1,1)] +void main() {} diff --git a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-constructors.hlsl b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-constructors.hlsl index d84e92242ffb4..ff31ac48616fe 100644 --- a/clang/test/CodeGenHLSL/builtins/StructuredBuffers-constructors.hlsl +++ b/clang/test/CodeGenHLSL/builtins/StructuredBuffers-constructors.hlsl @@ -1,5 +1,5 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s 
--check-prefixes=CHECK,CHECK-DXIL +// RUN-DISABLED: %clang_cc1 -triple spirv-vulkan-compute -emit-llvm -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV // NOTE: SPIRV codegen for resource types is not yet implemented @@ -21,6 +21,13 @@ RasterizerOrderedStructuredBuffer<float> Buf5 : register(u1, space2); // CHECK: @Buf4 = global %"class.hlsl::ConsumeStructuredBuffer" zeroinitializer, align 4 // CHECK: @Buf5 = global %"class.hlsl::RasterizerOrderedStructuredBuffer" zeroinitializer, align 4 +// CHECK: define void @main() +// CHECK-NEXT: entry: +// CHECK-NEXT: call void @_GLOBAL__sub_I_StructuredBuffers_constructors.hlsl() +// CHECK-NEXT: call void @_init_resource_bindings() +// CHECK-NEXT: call void @_Z4mainv() +// CHECK-NEXT: ret void + // CHECK: define linkonce_odr void @_ZN4hlsl16StructuredBufferIfEC2Ev(ptr noundef nonnull align 4 dereferenceable(4) %this) // CHECK-NEXT: entry: // CHECK: define linkonce_odr void @_ZN4hlsl18RWStructuredBufferIfEC2Ev(ptr noundef nonnull align 4 dereferenceable(4) %this) @@ -31,10 +38,6 @@ RasterizerOrderedStructuredBuffer<float> Buf5 : register(u1, space2); // CHECK: define linkonce_odr void @_ZN4hlsl33RasterizerOrderedStructuredBufferIfEC2Ev(ptr noundef nonnull align 4 dereferenceable(4) %this) // CHECK-NEXT: entry: -// CHECK: define internal void @_GLOBAL__sub_I_StructuredBuffers_constructors.hlsl() -// CHECK: entry: -// CHECK: call void @_init_resource_bindings() - // CHECK: define internal void @_init_resource_bindings() { // CHECK-NEXT: entry: // CHECK-DXIL-NEXT: %Buf_h = call target("dx.RawBuffer", float, 0, 0) @llvm.dx.handle.fromBinding.tdx.RawBuffer_f32_0_0t(i32 0, i32 10, i32 1, i32 0, i1 false) @@ -58,3 +61,6 @@ RasterizerOrderedStructuredBuffer<float> Buf5 : register(u1, space2); // CHECK-SPIRV-NEXT: store target("dx.RawBuffer", float, 1, 0) %Buf4_h, ptr @Buf4, align 4 // CHECK-SPIRV-NEXT: %Buf5_h = call target("dx.RawBuffer", float, 1, 1) @llvm.spv.handle.fromBinding.tdx.RawBuffer_f32_1_1t(i32 2, i32 1, i32 1, 
i32 0, i1 false) // CHECK-SPIRV-NEXT: store target("dx.RawBuffer", float, 1, 1) %Buf5_h, ptr @Buf5, align 4 + +[numthreads(4,1,1)] +void main() {} diff --git a/clang/test/CodeGenHLSL/resource-bindings.hlsl b/clang/test/CodeGenHLSL/resource-bindings.hlsl index bfec90e1871f8..8f6df7d6fac11 100644 --- a/clang/test/CodeGenHLSL/resource-bindings.hlsl +++ b/clang/test/CodeGenHLSL/resource-bindings.hlsl @@ -1,4 +1,4 @@ -// RUN: %clang_cc1 -triple dxil--shadermodel6.6-compute -x hlsl -finclude-default-header -emit-llvm -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil--shadermodel6.6-compute -finclude-default-header -emit-llvm -o - %s | FileCheck %s // CHECK: define internal void @_init_resource_bindings() { @@ -17,3 +17,6 @@ struct S { // CHECK: %T3S0_h = call target("dx.RawBuffer", %struct.S = type { <4 x float>, i32, [12 x i8] }, 0, 0) @llvm.dx.handle.fromBinding.tdx.RawBuffer_s_struct.Ss_0_0t(i32 0, i32 3, i32 1, i32 0, i1 false) StructuredBuffer<S> T3S0 : register(t3); + +[numthreads(4,1,1)] +void main() {}