Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion llvm/lib/Frontend/HLSL/CBuffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@ using namespace llvm::hlsl;

static size_t getMemberOffset(GlobalVariable *Handle, size_t Index) {
auto *HandleTy = cast<TargetExtType>(Handle->getValueType());
assert(HandleTy->getName().ends_with(".CBuffer") && "Not a cbuffer type");
assert((HandleTy->getName().ends_with(".CBuffer") ||
HandleTy->getName() == "spirv.VulkanBuffer") &&
"Not a cbuffer type");
assert(HandleTy->getNumTypeParameters() == 1 && "Expected layout type");

auto *LayoutTy = cast<TargetExtType>(HandleTy->getTypeParameter(0));
Expand Down
4 changes: 3 additions & 1 deletion llvm/lib/Target/SPIRV/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVTargetTransformInfo.cpp
SPIRVUtils.cpp
SPIRVEmitNonSemanticDI.cpp
SPIRVCBufferAccess.cpp

LINK_COMPONENTS
Analysis
Expand All @@ -57,8 +58,9 @@ add_llvm_target(SPIRVCodeGen
Core
Demangle
GlobalISel
SPIRVAnalysis
FrontendHLSL
MC
SPIRVAnalysis
SPIRVDesc
SPIRVInfo
ScalarOpts
Expand Down
2 changes: 2 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRV.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class RegisterBankInfo;

ModulePass *createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM);
FunctionPass *createSPIRVStructurizerPass();
ModulePass *createSPIRVCBufferAccessLegacyPass();
FunctionPass *createSPIRVMergeRegionExitTargetsPass();
FunctionPass *createSPIRVStripConvergenceIntrinsicsPass();
ModulePass *createSPIRVLegalizeImplicitBindingPass();
Expand All @@ -43,6 +44,7 @@ void initializeSPIRVPreLegalizerPass(PassRegistry &);
void initializeSPIRVPreLegalizerCombinerPass(PassRegistry &);
void initializeSPIRVPostLegalizerPass(PassRegistry &);
void initializeSPIRVStructurizerPass(PassRegistry &);
void initializeSPIRVCBufferAccessLegacyPass(PassRegistry &);
void initializeSPIRVEmitIntrinsicsPass(PassRegistry &);
void initializeSPIRVEmitNonSemanticDIPass(PassRegistry &);
void initializeSPIRVLegalizePointerCastPass(PassRegistry &);
Expand Down
172 changes: 172 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCBufferAccess.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads
//--------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass replaces all accesses to constant buffer global variables with
// accesses to the proper SPIR-V resource. It's designed to run after the
// DXIL preparation passes and before the main SPIR-V legalization passes.
//
// The pass operates as follows:
// 1. It finds all constant buffers by looking for the `!hlsl.cbs` metadata.
// 2. For each cbuffer, it finds the global variable holding the resource handle
// and the global variables for each of the cbuffer's members.
// 3. For each member variable, it creates a call to the
// `llvm.spv.resource.getpointer` intrinsic. This intrinsic takes the
// resource handle and the member's index within the cbuffer as arguments.
// The result is a pointer to that member within the SPIR-V resource.
// 4. It then replaces all uses of the original member global variable with the
// pointer returned by the `getpointer` intrinsic. This effectively retargets
// all loads and GEPs to the new resource pointer.
// 5. Finally, it cleans up by deleting the original global variables and the
// `!hlsl.cbs` metadata.
//
// This approach allows subsequent passes, like SPIRVEmitIntrinsics, to
// correctly handle GEPs that operate on the result of the `getpointer` call,
// folding them into a single OpAccessChain instruction.
//
//===----------------------------------------------------------------------===//

#include "SPIRVCBufferAccess.h"
#include "SPIRV.h"
#include "llvm/Frontend/HLSL/CBuffer.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicsSPIRV.h"
#include "llvm/IR/Module.h"

#define DEBUG_TYPE "spirv-cbuffer-access"
using namespace llvm;

// Finds the single instruction that defines the resource handle. This is
// typically a call to `llvm.spv.resource.handlefrombinding`.
static Instruction *findHandleDef(GlobalVariable *HandleVar) {
for (User *U : HandleVar->users()) {
if (auto *SI = dyn_cast<StoreInst>(U)) {
if (auto *I = dyn_cast<Instruction>(SI->getValueOperand())) {
return I;
}
}
}
return nullptr;
}

static bool replaceCBufferAccesses(Module &M) {
std::optional<hlsl::CBufferMetadata> CBufMD = hlsl::CBufferMetadata::get(M);
if (!CBufMD)
return false;

for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
Instruction *HandleDef = findHandleDef(Mapping.Handle);
if (!HandleDef) {
// If there's no handle definition, it might be because the cbuffer is
// unused. In this case, we can just clean up the globals.
if (Mapping.Handle->use_empty()) {
for (const auto &Member : Mapping.Members) {
if (Member.GV->use_empty()) {
Member.GV->eraseFromParent();
}
}
Mapping.Handle->eraseFromParent();
}
continue;
}

// The handle definition should dominate all uses of the cbuffer members.
// We'll insert our getpointer calls right after it.
IRBuilder<> Builder(HandleDef->getNextNode());

for (uint32_t Index = 0; Index < Mapping.Members.size(); ++Index) {
GlobalVariable *MemberGV = Mapping.Members[Index].GV;
if (MemberGV->use_empty()) {
continue;
}

// Create the getpointer intrinsic call.
Value *IndexVal = Builder.getInt32(Index);
Type *PtrType = MemberGV->getType();
Value *GetPointerCall = Builder.CreateIntrinsic(
PtrType, Intrinsic::spv_resource_getpointer, {HandleDef, IndexVal});

// We cannot use replaceAllUsesWith here because some uses may be
// ConstantExprs, which cannot be replaced with non-constants.
SmallVector<User *, 4> Users(MemberGV->users());
for (User *U : Users) {
if (auto *CE = dyn_cast<ConstantExpr>(U)) {
SmallVector<Instruction *, 4> Insts;
std::function<void(ConstantExpr *)> findInstructions =
[&](ConstantExpr *Const) {
for (User *ConstU : Const->users()) {
if (auto *ConstCE = dyn_cast<ConstantExpr>(ConstU)) {
findInstructions(ConstCE);
} else if (auto *I = dyn_cast<Instruction>(ConstU)) {
Insts.push_back(I);
}
}
};
findInstructions(CE);

for (Instruction *I : Insts) {
Instruction *NewInst = CE->getAsInstruction();
NewInst->insertBefore(I->getIterator());
I->replaceUsesOfWith(CE, NewInst);
NewInst->replaceUsesOfWith(MemberGV, GetPointerCall);
}
} else {
U->replaceUsesOfWith(MemberGV, GetPointerCall);
}
}
}
}

// Now that all uses are replaced, clean up the globals and metadata.
for (const hlsl::CBufferMapping &Mapping : *CBufMD) {
for (const auto &Member : Mapping.Members) {
Member.GV->eraseFromParent();
}
// Erase the stores to the handle variable before erasing the handle itself.
SmallVector<Instruction *, 4> HandleStores;
for (User *U : Mapping.Handle->users()) {
if (auto *SI = dyn_cast<StoreInst>(U)) {
HandleStores.push_back(SI);
}
}
for (Instruction *I : HandleStores) {
I->eraseFromParent();
}
Mapping.Handle->eraseFromParent();
}

CBufMD->eraseFromModule();
return true;
}

PreservedAnalyses SPIRVCBufferAccess::run(Module &M,
ModuleAnalysisManager &AM) {
if (replaceCBufferAccesses(M)) {
return PreservedAnalyses::none();
}
return PreservedAnalyses::all();
}

namespace {
class SPIRVCBufferAccessLegacy : public ModulePass {
public:
bool runOnModule(Module &M) override { return replaceCBufferAccesses(M); }
StringRef getPassName() const override { return "SPIRV CBuffer Access"; }
SPIRVCBufferAccessLegacy() : ModulePass(ID) {}

static char ID; // Pass identification.
};
char SPIRVCBufferAccessLegacy::ID = 0;
} // end anonymous namespace

INITIALIZE_PASS(SPIRVCBufferAccessLegacy, DEBUG_TYPE, "SPIRV CBuffer Access",
false, false)

ModulePass *llvm::createSPIRVCBufferAccessLegacyPass() {
return new SPIRVCBufferAccessLegacy();
}
23 changes: 23 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVCBufferAccess.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
//===- SPIRVCBufferAccess.cpp - Translate CBuffer Loads
//--------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//

#ifndef LLVM_LIB_TARGET_SPIRV_SPIRVCBUFFERACCESS_H_
#define LLVM_LIB_TARGET_SPIRV_SPIRVCBUFFERACCESS_H_

#include "llvm/IR/PassManager.h"

namespace llvm {

class SPIRVCBufferAccess : public PassInfoMixin<SPIRVCBufferAccess> {
public:
PreservedAnalyses run(Module &M, ModuleAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_SPIRV_SPIRVCBUFFERACCESS_H_
5 changes: 5 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,11 @@

// NOTE: NO INCLUDE GUARD DESIRED!

#ifndef MODULE_PASS
#define MODULE_PASS(NAME, CREATE_PASS)
#endif
MODULE_PASS("spirv-cbuffer-access", SPIRVCBufferAccess())
#undef MODULE_PASS

#ifndef FUNCTION_PASS
#define FUNCTION_PASS(NAME, CREATE_PASS)
Expand Down
4 changes: 4 additions & 0 deletions llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@

#include "SPIRVTargetMachine.h"
#include "SPIRV.h"
#include "SPIRVCBufferAccess.h"
#include "SPIRVCallLowering.h"
#include "SPIRVGlobalRegistry.h"
#include "SPIRVLegalizerInfo.h"
#include "SPIRVStructurizerWrapper.h"
Expand Down Expand Up @@ -48,6 +50,7 @@ extern "C" LLVM_ABI LLVM_EXTERNAL_VISIBILITY void LLVMInitializeSPIRVTarget() {
initializeSPIRVAsmPrinterPass(PR);
initializeSPIRVConvergenceRegionAnalysisWrapperPassPass(PR);
initializeSPIRVStructurizerPass(PR);
initializeSPIRVCBufferAccessLegacyPass(PR);
initializeSPIRVPreLegalizerCombinerPass(PR);
initializeSPIRVLegalizePointerCastPass(PR);
initializeSPIRVRegularizerPass(PR);
Expand Down Expand Up @@ -206,6 +209,7 @@ void SPIRVPassConfig::addISelPrepare() {

addPass(createSPIRVStripConvergenceIntrinsicsPass());
addPass(createSPIRVLegalizeImplicitBindingPass());
addPass(createSPIRVCBufferAccessLegacyPass());
addPass(createSPIRVEmitIntrinsicsPass(&getTM<SPIRVTargetMachine>()));
if (TM.getSubtargetImpl()->isLogicalSPIRV())
addPass(createSPIRVLegalizePointerCastPass(&getTM<SPIRVTargetMachine>()));
Expand Down
49 changes: 49 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-resources/cbuffer.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s

; CHECK-DAG: OpDecorate %[[MyCBuffer:[0-9]+]] DescriptorSet 0
; CHECK-DAG: OpDecorate %[[MyCBuffer]] Binding 0
; CHECK-DAG: OpMemberDecorate %[[__cblayout_MyCBuffer:[0-9]+]] 0 Offset 0
; CHECK-DAG: OpMemberDecorate %[[__cblayout_MyCBuffer]] 1 Offset 16
; CHECK-DAG: %[[uint:[0-9]+]] = OpTypeInt 32 0
; CHECK-DAG: %[[uint_0:[0-9]+]] = OpConstant %[[uint]] 0{{$}}
; CHECK-DAG: %[[uint_1:[0-9]+]] = OpConstant %[[uint]] 1{{$}}
; CHECK-DAG: %[[float:[0-9]+]] = OpTypeFloat 32
; CHECK-DAG: %[[v4float:[0-9]+]] = OpTypeVector %[[float]] 4
; CHECK-DAG: %[[__cblayout_MyCBuffer]] = OpTypeStruct %[[v4float]] %[[v4float]]
; CHECK-DAG: %[[wrapper:[0-9]+]] = OpTypeStruct %[[__cblayout_MyCBuffer]]
; CHECK-DAG: %[[wrapper_ptr_t:[0-9]+]] = OpTypePointer Uniform %[[wrapper]]
; CHECK-DAG: %[[MyCBuffer]] = OpVariable %[[wrapper_ptr_t]] Uniform
; CHECK-DAG: %[[_ptr_Uniform_v4float:[0-9]+]] = OpTypePointer Uniform %[[v4float]]

%__cblayout_MyCBuffer = type <{ <4 x float>, <4 x float> }>

@MyCBuffer.cb = local_unnamed_addr global target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) poison
@a = external hidden local_unnamed_addr addrspace(12) global <4 x float>, align 16
@b = external hidden local_unnamed_addr addrspace(12) global <4 x float>, align 16
@MyCBuffer.str = private unnamed_addr constant [10 x i8] c"MyCBuffer\00", align 1
@.str = private unnamed_addr constant [7 x i8] c"output\00", align 1

; Function Attrs: mustprogress nofree noinline norecurse nosync nounwind willreturn memory(readwrite, argmem: write, inaccessiblemem: none)
define void @main() local_unnamed_addr #1 {
entry:
; CHECK: %[[tmp:[0-9]+]] = OpCopyObject %[[wrapper_ptr_t]] %[[MyCBuffer]]
%MyCBuffer.cb_h.i.i = tail call target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_tspirv.Layout_s___cblayout_MyCBuffers_32_0_16t_2_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @MyCBuffer.str)
store target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) %MyCBuffer.cb_h.i.i, ptr @MyCBuffer.cb, align 8
%0 = tail call target("spirv.Image", <4 x float>, 5, 2, 0, 0, 2, 3) @llvm.spv.resource.handlefrombinding.tspirv.Image_v4f32_5_2_0_0_2_3t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)
; CHECK: %[[a_ptr:.+]] = OpAccessChain %[[_ptr_Uniform_v4float]] %[[tmp]] %[[uint_0]] %[[uint_0]]
; CHECK: %[[b_ptr:.+]] = OpAccessChain %[[_ptr_Uniform_v4float]] %[[tmp]] %[[uint_0]] %[[uint_1]]
; CHECK: %[[a_val:.+]] = OpLoad %[[v4float]] %[[a_ptr]]
; CHECK: %[[b_val:.+]] = OpLoad %[[v4float]] %[[b_ptr]]
%a_val = load <4 x float>, ptr addrspace(12) @a, align 16
%b_val = load <4 x float>, ptr addrspace(12) @b, align 16
%add = fadd <4 x float> %a_val, %b_val
%output_ptr = tail call noundef ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_v4f32_5_2_0_0_2_3t(target("spirv.Image", <4 x float>, 5, 2, 0, 0, 2, 3) %0, i32 0)
store <4 x float> %add, ptr addrspace(11) %output_ptr, align 16
ret void
}

attributes #1 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

!hlsl.cbs = !{!0}

!0 = !{ptr @MyCBuffer.cb, ptr addrspace(12) @a, ptr addrspace(12) @b}
57 changes: 57 additions & 0 deletions llvm/test/CodeGen/SPIRV/hlsl-resources/cbuffer_constant_expr.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
; Test that uses of cbuffer members inside ConstantExprs are handled correctly.

; CHECK-DAG: OpDecorate %[[MyCBuffer:[0-9]+]] DescriptorSet 0
; CHECK-DAG: OpDecorate %[[MyCBuffer]] Binding 0
; CHECK-DAG: OpMemberDecorate %[[__cblayout_MyCBuffer:[0-9]+]] 0 Offset 0
; CHECK-DAG: OpMemberDecorate %[[__cblayout_MyCBuffer]] 1 Offset 16
; CHECK-DAG: %[[uint:[0-9]+]] = OpTypeInt 32 0
; CHECK-DAG: %[[uint_0:[0-9]+]] = OpConstant %[[uint]] 0{{$}}
; CHECK-DAG: %[[uint_1:[0-9]+]] = OpConstant %[[uint]] 1{{$}}
; CHECK-DAG: %[[float:[0-9]+]] = OpTypeFloat 32
; CHECK-DAG: %[[v4float:[0-9]+]] = OpTypeVector %[[float]] 4
; CHECK-DAG: %[[MyStruct:[0-9]+]] = OpTypeStruct %[[v4float]]
; CHECK-DAG: %[[__cblayout_MyCBuffer]] = OpTypeStruct %[[MyStruct]] %[[v4float]]
; CHECK-DAG: %[[wrapper:[0-9]+]] = OpTypeStruct %[[__cblayout_MyCBuffer]]
; CHECK-DAG: %[[wrapper_ptr_t:[0-9]+]] = OpTypePointer Uniform %[[wrapper]]
; CHECK-DAG: %[[MyCBuffer]] = OpVariable %[[wrapper_ptr_t]] Uniform
; CHECK-DAG: %[[_ptr_Uniform_v4float:[0-9]+]] = OpTypePointer Uniform %[[v4float]]
; CHECK-DAG: %[[_ptr_Uniform_float:[0-9]+]] = OpTypePointer Uniform %[[float]]

%MyStruct = type { <4 x float> }
%__cblayout_MyCBuffer = type <{ %MyStruct, <4 x float> }>

@MyCBuffer.cb = local_unnamed_addr global target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) poison
@s = external hidden local_unnamed_addr addrspace(12) global %MyStruct, align 16
@v = external hidden local_unnamed_addr addrspace(12) global <4 x float>, align 16
@MyCBuffer.str = private unnamed_addr constant [10 x i8] c"MyCBuffer\00", align 1
@.str = private unnamed_addr constant [7 x i8] c"output\00", align 1

define void @main() {
entry:
; CHECK: %[[tmp:[0-9]+]] = OpCopyObject %[[wrapper_ptr_t]] %[[MyCBuffer]]
%MyCBuffer.cb_h.i.i = tail call target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_tspirv.Layout_s___cblayout_MyCBuffers_32_0_16t_2_0t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @MyCBuffer.str)
store target("spirv.VulkanBuffer", target("spirv.Layout", %__cblayout_MyCBuffer, 32, 0, 16), 2, 0) %MyCBuffer.cb_h.i.i, ptr @MyCBuffer.cb, align 8
%0 = tail call target("spirv.Image", float, 5, 2, 0, 0, 2, 3) @llvm.spv.resource.handlefrombinding.tspirv.Image_f32_5_2_0_0_2_3t(i32 0, i32 0, i32 1, i32 0, ptr nonnull @.str)

; This GEP is a ConstantExpr that uses @s
; CHECK: %[[tmp_ptr:[0-9]+]] = OpAccessChain {{%[0-9]+}} %[[tmp]] %[[uint_0]] %[[uint_0]]
; CHECK: %[[v_ptr:.+]] = OpAccessChain %[[_ptr_Uniform_v4float]] %[[tmp]] %[[uint_0]] %[[uint_1]]
; CHECK: %[[s_ptr_gep:[0-9]+]] = OpInBoundsAccessChain %[[_ptr_Uniform_float]] %[[tmp_ptr]] %[[uint_0]] %[[uint_1]]
%gep = getelementptr inbounds %MyStruct, ptr addrspace(12) @s, i32 0, i32 0, i32 1

; CHECK: %[[s_val:.+]] = OpLoad %[[float]] %[[s_ptr_gep]]
%load_from_gep = load float, ptr addrspace(12) %gep, align 4

; CHECK: %[[v_val:.+]] = OpLoad %[[v4float]] %[[v_ptr]]
%load_v = load <4 x float>, ptr addrspace(12) @v, align 16

%extract_v = extractelement <4 x float> %load_v, i64 0
%add = fadd float %load_from_gep, %extract_v
%get_output_ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.Image_f32_5_2_0_0_2_3t(target("spirv.Image", float, 5, 2, 0, 0, 2, 3) %0, i32 0)
store float %add, ptr addrspace(11) %get_output_ptr, align 4
ret void
}

!hlsl.cbs = !{!0}
!0 = !{ptr @MyCBuffer.cb, ptr addrspace(12) @s, ptr addrspace(12) @v}
Loading
Loading