From 40067a7c6bcc48d07b23d6e651656ec553d59f04 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Nathan=20Gau=C3=ABr?=
Date: Thu, 3 Apr 2025 15:07:25 +0200
Subject: [PATCH] [SPIR-V] Add legalize-addrspace-cast pass

This commit adds a new pass in the backend which propagates the
addrspace of the pointers down to the last use, making sure the
addrspace remains consistent, and thus stripping any addrspacecast.
This is required to lower LLVM IR to logical SPIR-V, which does not
support generic pointers.

This is now required as HLSL emits several address spaces, and thus
addrspacecasts in some cases:

Example 1: resource access

```llvm
  %handle = tail call target("spirv.VulkanBuffer", ...)
  %rptr = call ptr addrspace(11) @llvm.spv.resource.getpointer(%handle, ...)
  %cptr = addrspacecast ptr addrspace(11) %rptr to ptr
  %val = load i32, ptr %cptr
```

Example 2: object methods

```llvm
define void @objectMethod(ptr %this) {
}

define void @foo(ptr addrspace(11) %object) {
  %this = addrspacecast ptr addrspace(11) %object to ptr
  call void @objectMethod(ptr %this)
}
```
---
 llvm/lib/Target/SPIRV/CMakeLists.txt          |   1 +
 llvm/lib/Target/SPIRV/SPIRV.h                 |   1 +
 .../SPIRV/SPIRVLegalizeAddrspaceCast.cpp      | 141 ++++++++++++++++++
 llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp  |   2 +
 .../SPIRV/pointers/pointer-addrspacecast.ll   |  36 +++++
 .../pointers/resource-addrspacecast-2.ll      |  54 +++++++
 .../SPIRV/pointers/resource-addrspacecast.ll  |  37 +++++
 7 files changed, 272 insertions(+)
 create mode 100644 llvm/lib/Target/SPIRV/SPIRVLegalizeAddrspaceCast.cpp
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll
 create mode 100644 llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll

diff --git a/llvm/lib/Target/SPIRV/CMakeLists.txt b/llvm/lib/Target/SPIRV/CMakeLists.txt
index 4a2b534b948d6..8cc709ab2892c 100644
--- a/llvm/lib/Target/SPIRV/CMakeLists.txt
+++ b/llvm/lib/Target/SPIRV/CMakeLists.txt
@@ -28,6 +28,7 @@ add_llvm_target(SPIRVCodeGen
SPIRVInstructionSelector.cpp SPIRVStripConvergentIntrinsics.cpp SPIRVLegalizePointerCast.cpp + SPIRVLegalizeAddrspaceCast.cpp SPIRVMergeRegionExitTargets.cpp SPIRVISelLowering.cpp SPIRVLegalizerInfo.cpp diff --git a/llvm/lib/Target/SPIRV/SPIRV.h b/llvm/lib/Target/SPIRV/SPIRV.h index 51728d1aa678d..63b3f8ee2b467 100644 --- a/llvm/lib/Target/SPIRV/SPIRV.h +++ b/llvm/lib/Target/SPIRV/SPIRV.h @@ -24,6 +24,7 @@ FunctionPass *createSPIRVStructurizerPass(); FunctionPass *createSPIRVMergeRegionExitTargetsPass(); FunctionPass *createSPIRVStripConvergenceIntrinsicsPass(); FunctionPass *createSPIRVLegalizePointerCastPass(SPIRVTargetMachine *TM); +FunctionPass *createSPIRVLegalizeAddrspaceCastPass(SPIRVTargetMachine *TM); FunctionPass *createSPIRVRegularizerPass(); FunctionPass *createSPIRVPreLegalizerCombiner(); FunctionPass *createSPIRVPreLegalizerPass(); diff --git a/llvm/lib/Target/SPIRV/SPIRVLegalizeAddrspaceCast.cpp b/llvm/lib/Target/SPIRV/SPIRVLegalizeAddrspaceCast.cpp new file mode 100644 index 0000000000000..f99dd3c428b6a --- /dev/null +++ b/llvm/lib/Target/SPIRV/SPIRVLegalizeAddrspaceCast.cpp @@ -0,0 +1,141 @@ +//===-- SPIRVLegalizeAddrspaceCast.cpp ----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "SPIRV.h" +#include "SPIRVSubtarget.h" +#include "SPIRVTargetMachine.h" +#include "SPIRVUtils.h" +#include "llvm/CodeGen/IntrinsicLowering.h" +#include "llvm/IR/IRBuilder.h" +#include "llvm/IR/IntrinsicInst.h" +#include "llvm/IR/Intrinsics.h" +#include "llvm/IR/IntrinsicsSPIRV.h" +#include "llvm/Transforms/Utils/Cloning.h" +#include "llvm/Transforms/Utils/LowerMemIntrinsics.h" + +using namespace llvm; + +namespace llvm { +void initializeSPIRVLegalizeAddrspaceCastPass(PassRegistry &); +} + +class SPIRVLegalizeAddrspaceCast : public FunctionPass { + +public: + SPIRVLegalizeAddrspaceCast(SPIRVTargetMachine *TM) + : FunctionPass(ID), TM(TM) { + initializeSPIRVLegalizeAddrspaceCastPass(*PassRegistry::getPassRegistry()); + }; + + void gatherAddrspaceCast(Function &F) { + WorkList.clear(); + std::vector ToVisit; + for (auto &BB : F) + for (auto &I : BB) + ToVisit.push_back(&I); + + std::unordered_set Visited; + while (ToVisit.size() > 0) { + User *I = ToVisit.back(); + ToVisit.pop_back(); + if (Visited.count(I) != 0) + continue; + Visited.insert(I); + + if (AddrSpaceCastInst *AI = dyn_cast(I)) + WorkList.insert(AI); + else if (auto *AO = dyn_cast(I)) + WorkList.insert(AO); + + for (auto &O : I->operands()) + if (User *U = dyn_cast(&O)) + ToVisit.push_back(U); + } + } + + void propagateAddrspace(User *U) { + if (!U->getType()->isPointerTy()) + return; + + if (AddrSpaceCastOperator *AO = dyn_cast(U)) { + for (auto &Use : AO->uses()) + WorkList.insert(Use.getUser()); + + AO->mutateType(AO->getPointerOperand()->getType()); + AO->replaceAllUsesWith(AO->getPointerOperand()); + DeadUsers.insert(AO); + return; + } + + if (AddrSpaceCastInst *AC = dyn_cast(U)) { + for (auto &Use : AC->uses()) + WorkList.insert(Use.getUser()); + + AC->mutateType(AC->getPointerOperand()->getType()); + 
AC->replaceAllUsesWith(AC->getPointerOperand()); + return; + } + + PointerType *NewType = nullptr; + for (Use &U : U->operands()) { + PointerType *PT = dyn_cast(U.get()->getType()); + if (!PT) + continue; + + if (NewType == nullptr) + NewType = PT; + else { + // We could imagine a function calls taking 2 pointers to distinct + // address spaces which returns a pointer. But we want to run this + // pass after inlining, so we'll assume this doesn't happen. + assert(NewType->getAddressSpace() == PT->getAddressSpace()); + } + } + + assert(NewType != nullptr); + U->mutateType(NewType); + } + + virtual bool runOnFunction(Function &F) override { + const SPIRVSubtarget &ST = TM->getSubtarget(F); + GR = ST.getSPIRVGlobalRegistry(); + + DeadUsers.clear(); + gatherAddrspaceCast(F); + + while (WorkList.size() > 0) { + User *U = *WorkList.begin(); + WorkList.erase(U); + propagateAddrspace(U); + } + + for (User *U : DeadUsers) { + if (Instruction *I = dyn_cast(U)) + I->eraseFromParent(); + } + return DeadUsers.size() != 0; + } + +private: + SPIRVTargetMachine *TM = nullptr; + SPIRVGlobalRegistry *GR = nullptr; + std::unordered_set WorkList; + std::unordered_set DeadUsers; + +public: + static char ID; +}; + +char SPIRVLegalizeAddrspaceCast::ID = 0; +INITIALIZE_PASS(SPIRVLegalizeAddrspaceCast, "spirv-legalize-addrspacecast", + "SPIRV legalize addrspacecast", false, false) + +FunctionPass * +llvm::createSPIRVLegalizeAddrspaceCastPass(SPIRVTargetMachine *TM) { + return new SPIRVLegalizeAddrspaceCast(TM); +} diff --git a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp index 68286737b972f..b4c13d879b58e 100644 --- a/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVTargetMachine.cpp @@ -190,6 +190,8 @@ void SPIRVPassConfig::addIRPasses() { TargetPassConfig::addIRPasses(); if (TM.getSubtargetImpl()->isVulkanEnv()) { + addPass(createSPIRVLegalizeAddrspaceCastPass(&getTM())); + // 1. 
Simplify loop for subsequent transformations. After this steps, loops
     //    have the following properties:
     //      - loops have a single entry edge (pre-header to loop header).
diff --git a/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll b/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll
new file mode 100644
index 0000000000000..4d5549dfab8d9
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/pointer-addrspacecast.ll
@@ -0,0 +1,36 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#ptr_uint:]] = OpTypePointer Private %[[#uint]]
+; CHECK-DAG: %[[#var:]] = OpVariable %[[#ptr_uint]] Private %[[#uint_0]]
+
+; CHECK-DAG: OpName %[[#func_simple:]] "simple"
+; CHECK-DAG: OpName %[[#func_chain:]] "chain"
+
+@global = internal addrspace(10) global i32 zeroinitializer ; expected to become the Private OpVariable matched above
+
+define void @simple() { ; one addrspacecast between the GEP and the load
+; CHECK: %[[#func_simple]] = OpFunction
+entry:
+  %ptr = getelementptr i32, ptr addrspace(10) @global, i32 0
+  %casted = addrspacecast ptr addrspace(10) %ptr to ptr ; addrspace(10) to the default address space
+  %val = load i32, ptr %casted ; expected to load directly from the variable
+; CHECK: %{{.*}} = OpLoad %[[#uint]] %[[#var]] Aligned 4
+  ret void
+}
+
+define void @chain() { ; casts back and forth around a second GEP
+; CHECK: %[[#func_chain]] = OpFunction
+entry:
+  %a = getelementptr i32, ptr addrspace(10) @global, i32 0
+  %b = addrspacecast ptr addrspace(10) %a to ptr
+  %c = getelementptr i32, ptr %b, i32 0 ; GEP on the casted pointer
+  %d = addrspacecast ptr %c to ptr addrspace(10) ; back to addrspace(10)
+  %e = addrspacecast ptr addrspace(10) %d to ptr ; and to the default address space again
+
+  %val = load i32, ptr %e ; expected to load directly from the variable
+; CHECK: %{{.*}} = OpLoad %[[#uint]] %[[#var]] Aligned 4
+  ret void
+}
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll
new file mode 100644
index 0000000000000..93208c16ed4a5
--- /dev/null
+++
b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast-2.ll
@@ -0,0 +1,54 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s --match-full-lines
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; FIXME(134119): enable this once Offset decorations are added.
+; XFAIL: spirv-tools
+
+%S2 = type { { [10 x { i32, i32 } ] }, i32 }
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_1:]] = OpConstant %[[#uint]] 1
+; CHECK-DAG: %[[#uint_3:]] = OpConstant %[[#uint]] 3
+; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
+; CHECK-DAG: %[[#uint_11:]] = OpConstant %[[#uint]] 11
+; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
+
+; CHECK-DAG: %[[#t_s2_s_a_s:]] = OpTypeStruct %[[#uint]] %[[#uint]]
+; CHECK-DAG: %[[#t_s2_s_a:]] = OpTypeArray %[[#t_s2_s_a_s]] %[[#uint_10]]
+; CHECK-DAG: %[[#t_s2_s:]] = OpTypeStruct %[[#t_s2_s_a]]
+; CHECK-DAG: %[[#t_s2:]] = OpTypeStruct %[[#t_s2_s]] %[[#uint]]
+
+; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#t_s2]]
+; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#t_s2]]
+; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
+; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
+
+declare target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_S2s_12_1t(i32, i32, i32, i32, i1)
+
+define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
+entry:
+  %handle = tail call target("spirv.VulkanBuffer", [0 x %S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_S2s_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
+; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
+
+  %ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_S2s_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1) %handle, i32 0)
+; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
+; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a]] %[[#uint_0]] %[[#uint_0]]
+  %casted = addrspacecast ptr addrspace(11) %ptr to ptr
+
+; CHECK: %[[#ptr2:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b]] %[[#uint_0]] %[[#uint_0]] %[[#uint_3]] %[[#uint_1]]
+  %ptr2 = getelementptr inbounds %S2, ptr %casted, i64 0, i32 0, i32 0, i32 3, i32 1
+
+; CHECK: OpStore %[[#ptr2]] %[[#uint_10]] Aligned 4
+  store i32 10, ptr %ptr2, align 4
+
+; Another store, but this time using LLVM's ability to access the first element
+; without an explicit GEP. The backend has to determine the ptr type and
+; generate the appropriate access chain.
+; CHECK: %[[#ptr3:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: OpStore %[[#ptr3]] %[[#uint_11]] Aligned 4
+  store i32 11, ptr %casted, align 4
+  ret void
+}
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_S2s_12_1t(target("spirv.VulkanBuffer", [0 x %S2], 12, 1), i32)
diff --git a/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll
new file mode 100644
index 0000000000000..24a50c7177340
--- /dev/null
+++ b/llvm/test/CodeGen/SPIRV/pointers/resource-addrspacecast.ll
@@ -0,0 +1,37 @@
+; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
+; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
+
+; FIXME(134119): enable this once Offset decorations are added.
+; XFAIL: spirv-tools
+
+%struct.S = type { i32 }
+
+; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
+; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
+; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
+; CHECK-DAG: %[[#ptr_StorageBuffer_uint:]] = OpTypePointer StorageBuffer %[[#uint]]
+; CHECK-DAG: %[[#struct:]] = OpTypeStruct %[[#uint]]
+; CHECK-DAG: %[[#ptr_StorageBuffer_struct:]] = OpTypePointer StorageBuffer %[[#struct]]
+; CHECK-DAG: %[[#rarr:]] = OpTypeRuntimeArray %[[#struct]]
+; CHECK-DAG: %[[#rarr_struct:]] = OpTypeStruct %[[#rarr]]
+; CHECK-DAG: %[[#spirv_VulkanBuffer:]] = OpTypePointer StorageBuffer %[[#rarr_struct]]
+
+declare target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32, i32, i32, i32, i1)
+
+define void @main() "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" {
+entry:
+  %handle = tail call target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false)
+; CHECK: %[[#resource:]] = OpVariable %[[#spirv_VulkanBuffer]] StorageBuffer
+
+  %ptr = tail call noundef align 4 dereferenceable(4) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1) %handle, i32 0)
+; CHECK: %[[#a:]] = OpCopyObject %[[#spirv_VulkanBuffer]] %[[#resource]]
+; CHECK: %[[#b:]] = OpAccessChain %[[#ptr_StorageBuffer_struct]] %[[#a]] %[[#uint_0]] %[[#uint_0]]
+; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#ptr_StorageBuffer_uint]] %[[#b]] %[[#uint_0]]
+  %casted = addrspacecast ptr addrspace(11) %ptr to ptr
+
+; CHECK: OpStore %[[#c]] %[[#uint_10]] Aligned 4
+  store i32 10, ptr %casted, align 4
+  ret void
+}
+
+declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.Ss_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S], 12, 1), i32)