@bogner (Contributor) commented Nov 19, 2024

This pass transforms resource access via llvm.dx.resource.getpointer into buffer loads and stores.

Fixes #114848.
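
For readers following the patch, here is a minimal sketch of the index arithmetic the pass applies when a GEP with a constant offset sits between `llvm.dx.resource.getpointer` and the load or store: the byte offset is divided by the scalar size to get a vector component index. The helper name `componentIndexFromByteOffset` and its parameters are hypothetical and not part of the patch. The alignment and bounds checks are additions in this sketch; the patch itself only performs the division.

```cpp
#include <cassert>
#include <cstdint>
#include <optional>

// Hypothetical helper, not part of the patch: models how a constant byte
// offset into one typed-buffer element maps to a vector component index.
std::optional<uint32_t> componentIndexFromByteOffset(uint64_t ByteOffset,
                                                     uint64_t ScalarSizeInBits,
                                                     uint32_t NumComponents) {
  // Same conversion as the patch: size in bits divided by 8.
  uint64_t ScalarSizeInBytes = ScalarSizeInBits / 8;
  if (ScalarSizeInBytes == 0)
    return std::nullopt; // Sub-byte scalars are not modeled here.

  // Assumption of this sketch: reject offsets that do not land on a
  // component boundary instead of silently truncating.
  if (ByteOffset % ScalarSizeInBytes != 0)
    return std::nullopt;

  // Assumption of this sketch: reject indices past the end of the vector.
  uint64_t Index = ByteOffset / ScalarSizeInBytes;
  if (Index >= NumComponents)
    return std::nullopt;

  return static_cast<uint32_t>(Index);
}
```

Under these assumptions, a 4-byte offset into a `<4 x float>` element selects component 1 (the `.y` component), while a 16-byte offset is rejected because it points past the fourth component.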

@llvmbot added the backend:DirectX, llvm:ir, llvm:analysis, and llvm:transforms labels Nov 19, 2024
@llvmbot (Member) commented Nov 19, 2024

@llvm/pr-subscribers-llvm-analysis
@llvm/pr-subscribers-llvm-transforms

@llvm/pr-subscribers-llvm-ir

Author: Justin Bogner (bogner)

Changes

This pass transforms resource access via llvm.dx.resource.getpointer into buffer loads and stores.

Fixes #114848.


Patch is 24.25 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/116726.diff

13 Files Affected:

  • (modified) llvm/include/llvm/Analysis/DXILResource.h (+3)
  • (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+3)
  • (modified) llvm/lib/Analysis/DXILResource.cpp (+6)
  • (modified) llvm/lib/Target/DirectX/CMakeLists.txt (+1)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.cpp (+196)
  • (added) llvm/lib/Target/DirectX/DXILResourceAccess.h (+28)
  • (modified) llvm/lib/Target/DirectX/DirectX.h (+7)
  • (modified) llvm/lib/Target/DirectX/DirectXPassRegistry.def (+6)
  • (modified) llvm/lib/Target/DirectX/DirectXTargetMachine.cpp (+4-1)
  • (modified) llvm/lib/Transforms/Scalar/Scalarizer.cpp (+3)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll (+35)
  • (added) llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll (+103)
  • (modified) llvm/test/CodeGen/DirectX/llc-pipeline.ll (+2-2)
diff --git a/llvm/include/llvm/Analysis/DXILResource.h b/llvm/include/llvm/Analysis/DXILResource.h
index 6b577c02f05450..cd2ea3e07ee5b5 100644
--- a/llvm/include/llvm/Analysis/DXILResource.h
+++ b/llvm/include/llvm/Analysis/DXILResource.h
@@ -275,6 +275,9 @@ class DXILResourceMap {
   DXILResourceMap(
       SmallVectorImpl<std::pair<CallInst *, dxil::ResourceInfo>> &&CIToRI);
 
+  bool invalidate(Module &M, const PreservedAnalyses &PA,
+                  ModuleAnalysisManager::Invalidator &Inv);
+
   iterator begin() { return Resources.begin(); }
   const_iterator begin() const { return Resources.begin(); }
   iterator end() { return Resources.end(); }
diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td
index 48a9595f844f05..0d324f541d7663 100644
--- a/llvm/include/llvm/IR/IntrinsicsDirectX.td
+++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td
@@ -27,6 +27,9 @@ def int_dx_handle_fromBinding
           [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i32_ty, llvm_i1_ty],
           [IntrNoMem]>;
 
+def int_dx_resource_getpointer
+    : DefaultAttrsIntrinsic<[llvm_anyptr_ty], [llvm_any_ty, llvm_i32_ty],
+                            [IntrNoMem]>;
 def int_dx_typedBufferLoad
     : DefaultAttrsIntrinsic<[llvm_any_ty], [llvm_any_ty, llvm_i32_ty],
                             [IntrReadMem]>;
diff --git a/llvm/lib/Analysis/DXILResource.cpp b/llvm/lib/Analysis/DXILResource.cpp
index 2802480481690d..44909376928d65 100644
--- a/llvm/lib/Analysis/DXILResource.cpp
+++ b/llvm/lib/Analysis/DXILResource.cpp
@@ -744,6 +744,12 @@ DXILResourceMap::DXILResourceMap(
   }
 }
 
+bool DXILResourceMap::invalidate(Module &M, const PreservedAnalyses &PA,
+                                 ModuleAnalysisManager::Invalidator &Inv) {
+  auto PAC = PA.getChecker<DXILResourceAnalysis>();
+  return !(PAC.preserved() || PAC.preservedSet<AllAnalysesOn<Module>>());
+}
+
 void DXILResourceMap::print(raw_ostream &OS) const {
   for (unsigned I = 0, E = Resources.size(); I != E; ++I) {
     OS << "Binding " << I << ":\n";
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index a726071e0dcecd..26315db891b577 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -30,6 +30,7 @@ add_llvm_target(DirectXCodeGen
   DXILPrettyPrinter.cpp
   DXILResource.cpp
   DXILResourceAnalysis.cpp
+  DXILResourceAccess.cpp
   DXILShaderFlags.cpp
   DXILTranslateMetadata.cpp
 
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.cpp b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
new file mode 100644
index 00000000000000..f9b28800b74909
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.cpp
@@ -0,0 +1,196 @@
+//===- DXILResourceAccess.cpp - Resource access via load/store ------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILResourceAccess.h"
+#include "DirectX.h"
+#include "llvm/Analysis/DXILResource.h"
+#include "llvm/IR/Dominators.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsDirectX.h"
+#include "llvm/InitializePasses.h"
+
+#define DEBUG_TYPE "dxil-resource-access"
+
+using namespace llvm;
+
+static void replaceTypedBufferAccess(IntrinsicInst *II,
+                                     dxil::ResourceInfo &RI) {
+  const DataLayout &DL = II->getDataLayout();
+
+  auto *HandleType = cast<TargetExtType>(II->getOperand(0)->getType());
+  assert(HandleType->getName() == "dx.TypedBuffer" &&
+         "Unexpected typed buffer type");
+  Type *ContainedType = HandleType->getTypeParameter(0);
+  Type *ScalarType = ContainedType->getScalarType();
+  uint64_t ScalarSize = DL.getTypeSizeInBits(ScalarType) / 8;
+  int NumElements = ContainedType->getNumContainedTypes();
+  if (!NumElements)
+    NumElements = 1;
+
+  // Process users keeping track of indexing accumulated from GEPs.
+  struct AccessAndIndex {
+    User *Access;
+    Value *Index;
+  };
+  SmallVector<AccessAndIndex> Worklist;
+  for (User *U : II->users())
+    Worklist.push_back({U, nullptr});
+
+  SmallVector<Instruction *> DeadInsts;
+  while (!Worklist.empty()) {
+    AccessAndIndex Current = Worklist.back();
+    Worklist.pop_back();
+
+    if (auto *GEP = dyn_cast<GetElementPtrInst>(Current.Access)) {
+      IRBuilder<> Builder(GEP);
+
+      Value *Index;
+      APInt ConstantOffset(DL.getIndexTypeSizeInBits(GEP->getType()), 0);
+      if (GEP->accumulateConstantOffset(DL, ConstantOffset)) {
+        APInt Scaled = ConstantOffset.udiv(ScalarSize);
+        Index = ConstantInt::get(Builder.getInt32Ty(), Scaled);
+      } else {
+        auto IndexIt = GEP->idx_begin();
+        assert(cast<ConstantInt>(IndexIt)->getZExtValue() == 0 &&
+               "GEP is not indexing through pointer");
+        ++IndexIt;
+        Index = *IndexIt;
+        assert(++IndexIt == GEP->idx_end() && "Too many indices in GEP");
+      }
+
+      for (User *U : GEP->users())
+        Worklist.push_back({U, Index});
+      DeadInsts.push_back(GEP);
+
+    } else if (auto *SI = dyn_cast<StoreInst>(Current.Access)) {
+      assert(SI->getValueOperand() != II && "Pointer escaped!");
+      IRBuilder<> Builder(SI);
+
+      Value *V = SI->getValueOperand();
+      if (V->getType() == ContainedType) {
+        // V is already the right type.
+      } else if (V->getType() == ScalarType) {
+        // We're storing a scalar, so we need to load the current value and only
+        // replace the relevant part.
+        auto *Load = Builder.CreateIntrinsic(
+            ContainedType, Intrinsic::dx_typedBufferLoad,
+            {II->getOperand(0), II->getOperand(1)});
+        // If we have an offset from seeing a GEP earlier, use it.
+        Value *IndexOp = Current.Index
+                             ? Current.Index
+                             : ConstantInt::get(Builder.getInt32Ty(), 0);
+        V = Builder.CreateInsertElement(Load, V, IndexOp);
+      } else {
+        llvm_unreachable("Store to typed resource has invalid type");
+      }
+
+      auto *Inst = Builder.CreateIntrinsic(
+          Builder.getVoidTy(), Intrinsic::dx_typedBufferStore,
+          {II->getOperand(0), II->getOperand(1), V});
+      SI->replaceAllUsesWith(Inst);
+      DeadInsts.push_back(SI);
+
+    } else if (auto *LI = dyn_cast<LoadInst>(Current.Access)) {
+      IRBuilder<> Builder(LI);
+      Value *V =
+          Builder.CreateIntrinsic(ContainedType, Intrinsic::dx_typedBufferLoad,
+                                  {II->getOperand(0), II->getOperand(1)});
+      if (Current.Index)
+        V = Builder.CreateExtractElement(V, Current.Index);
+
+      LI->replaceAllUsesWith(V);
+      DeadInsts.push_back(LI);
+
+    } else
+      llvm_unreachable("Unhandled instruction - pointer escaped?");
+  }
+
+  // Erase the now-dead instructions in reverse order of discovery, so that
+  // each load/store is removed before the GEP it uses.
+  for (Instruction *Dead : llvm::reverse(DeadInsts))
+    Dead->eraseFromParent();
+  II->eraseFromParent();
+}
+
+static bool transformResourcePointers(Function &F, DXILResourceMap &DRM) {
+  // TODO: Should we have a more efficient way to find resources used in a
+  // particular function?
+  SmallVector<std::pair<IntrinsicInst *, dxil::ResourceInfo &>> Resources;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (auto *CI = dyn_cast<CallInst>(&I)) {
+        auto It = DRM.find(CI);
+        if (It == DRM.end())
+          continue;
+        for (User *U : CI->users())
+          if (IntrinsicInst *II = dyn_cast<IntrinsicInst>(U))
+            if (II->getIntrinsicID() == Intrinsic::dx_resource_getpointer)
+              Resources.emplace_back(II, *It);
+      }
+
+  bool Changed = false;
+  for (const auto &[II, RI] : Resources) {
+    if (RI.isTyped())
+      replaceTypedBufferAccess(II, RI);
+    // TODO: handle other resource types. We should probably have an
+    // `unreachable` here once we've added support for all of them.
+    Changed |= RI.isTyped();
+  }
+  return Changed;
+}
+
+PreservedAnalyses DXILResourceAccess::run(Function &F,
+                                          FunctionAnalysisManager &FAM) {
+  auto &MAMProxy = FAM.getResult<ModuleAnalysisManagerFunctionProxy>(F);
+  DXILResourceMap *DRM =
+      MAMProxy.getCachedResult<DXILResourceAnalysis>(*F.getParent());
+  assert(DRM && "DXILResourceAnalysis must be available");
+
+  bool MadeChanges = transformResourcePointers(F, *DRM);
+  if (!MadeChanges)
+    return PreservedAnalyses::all();
+
+  PreservedAnalyses PA;
+  PA.preserve<DXILResourceAnalysis>();
+  PA.preserve<DominatorTreeAnalysis>();
+  return PA;
+}
+
+namespace {
+class DXILResourceAccessLegacy : public FunctionPass {
+public:
+  bool runOnFunction(Function &F) override {
+    DXILResourceMap &DRM =
+        getAnalysis<DXILResourceWrapperPass>().getResourceMap();
+
+    return transformResourcePointers(F, DRM);
+  }
+  StringRef getPassName() const override { return "DXIL Resource Access"; }
+  DXILResourceAccessLegacy() : FunctionPass(ID) {}
+
+  static char ID; // Pass identification.
+  void getAnalysisUsage(llvm::AnalysisUsage &AU) const override {
+    AU.addRequired<DXILResourceWrapperPass>();
+    AU.addPreserved<DXILResourceWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+  }
+};
+char DXILResourceAccessLegacy::ID = 0;
+} // end anonymous namespace
+
+INITIALIZE_PASS_BEGIN(DXILResourceAccessLegacy, DEBUG_TYPE,
+                      "DXIL Resource Access", false, false)
+INITIALIZE_PASS_DEPENDENCY(DXILResourceWrapperPass)
+INITIALIZE_PASS_END(DXILResourceAccessLegacy, DEBUG_TYPE,
+                    "DXIL Resource Access", false, false)
+
+FunctionPass *llvm::createDXILResourceAccessLegacyPass() {
+  return new DXILResourceAccessLegacy();
+}
diff --git a/llvm/lib/Target/DirectX/DXILResourceAccess.h b/llvm/lib/Target/DirectX/DXILResourceAccess.h
new file mode 100644
index 00000000000000..ac47db21266f64
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILResourceAccess.h
@@ -0,0 +1,28 @@
+//===- DXILResourceAccess.h - Resource access via load/store ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// \file Pass for replacing pointers to DXIL resources with load and store
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILResourceAccess: public PassInfoMixin<DXILResourceAccess> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3454f16ecd5955..add23587de7d58 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -12,6 +12,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
 
 namespace llvm {
+class FunctionPass;
 class ModulePass;
 class PassRegistry;
 class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 /// Pass to lowering LLVM intrinsic call to DXIL op function call.
 ModulePass *createDXILOpLoweringLegacyPass();
 
+/// Initializer for DXILResourceAccess
+void initializeDXILResourceAccessLegacyPass(PassRegistry &);
+
+/// Pass to update resource accesses to use load/store directly.
+FunctionPass *createDXILResourceAccessLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index a0f864ed39375f..87591b104ce52c 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
 // TODO: rename to print<foo> after NPM switch
 MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
 #undef MODULE_PASS
+
+#ifndef FUNCTION_PASS
+#define FUNCTION_PASS(NAME, CREATE_PASS)
+#endif
+FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
+#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 655427a3e80209..9dade16ffe2732 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -17,6 +17,7 @@
 #include "DXILIntrinsicExpansion.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
+#include "DXILResourceAccess.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeWriteDXILPassPass(*PR);
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
+  initializeDXILResourceAccessLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -91,9 +93,10 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
+    addPass(createDXILFlattenArraysLegacyPass());
+    addPass(createDXILResourceAccessLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
-    addPass(createDXILFlattenArraysLegacyPass());
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILFinalizeLinkageLegacyPass());
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 03d069c9fcb36d..9341bc8bc02de6 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Argument.h"
@@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<DXILResourceWrapperPass>();
 }
 
 char ScalarizerLegacyPass::ID = 0;
@@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
   bool Changed = Impl.visit(F);
   PreservedAnalyses PA;
   PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<DXILResourceAnalysis>();
   return Changed ? PA : PreservedAnalyses::all();
 }
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
new file mode 100644
index 00000000000000..2c17ec674632ba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
@@ -0,0 +1,35 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @use_float4(<4 x float>)
+declare void @use_float(float)
+
+; CHECK-LABEL: define void @load_float4
+define void @load_float4(i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  %vec_data = load <4 x float>, ptr %ptr
+  call void @use_float4(<4 x float> %vec_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
+  %y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
+  %y_data = load float, ptr %y_ptr
+  call void @use_float(float %y_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  %dyndata = load float, ptr %dynamic
+  call void @use_float(float %dyndata)
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
new file mode 100644
index 00000000000000..dd63acc3c0e96c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
@@ -0,0 +1,103 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; CHECK-LABEL: define void @store_float4
+define void @store_float4(<4 x float> %data, i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; Store the whole value
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %data)
+  store <4 x float> %data, ptr %ptr
+
+  ; Store just the .x component
+  %scalar = extractelement <4 x float> %data, i32 0
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 0
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  store float %scalar, ptr %ptr
+
+  ; Store just the .y component
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 1
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %y_ptr = getelementptr inbounds i8, ptr %ptr, i32 4
+  store float %scalar, ptr %y_ptr
+
+  ; Store to one of the elements dynamically
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 %elemindex
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  store float %scalar, ptr %dynamic
+
+  ret void
+}
+
+; CHECK-LABEL: define void @store_half4
+define void @store_half4(<4 x half> %data, i32 %index) {
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
+      @llvm....
[truncated]
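
The scalar-store case in the tests above relies on a read-modify-write sequence: the pass loads the whole element, inserts the scalar at the component index, and stores the whole element back. The following is a minimal host-side sketch of that sequence, assuming a plain in-memory model. `TypedBufferModel`, `storeComponent`, and `loadComponent` are hypothetical names for illustration only and do not appear in the patch.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

using Float4 = std::array<float, 4>;

// Hypothetical stand-in for a typed buffer whose elements are <4 x float>.
struct TypedBufferModel {
  std::vector<Float4> Elements;

  // Models llvm.dx.typedBufferLoad: always reads a whole element.
  Float4 load(std::size_t Index) const { return Elements.at(Index); }

  // Models llvm.dx.typedBufferStore: always writes a whole element.
  void store(std::size_t Index, const Float4 &Value) {
    Elements.at(Index) = Value;
  }
};

// Scalar store: load the current element, replace one component (the
// insertelement step), then store the whole element back.
void storeComponent(TypedBufferModel &Buffer, std::size_t Index,
                    std::size_t Component, float Scalar) {
  Float4 Whole = Buffer.load(Index);
  Whole.at(Component) = Scalar;
  Buffer.store(Index, Whole);
}

// Scalar load: load the whole element, then pick one component (the
// extractelement step).
float loadComponent(const TypedBufferModel &Buffer, std::size_t Index,
                    std::size_t Component) {
  return Buffer.load(Index).at(Component);
}
```

In this model, writing one component leaves the other three components of the same element, and every other element, unchanged. That is the behavior the `insertelement` CHECK lines in `store_typedbuffer.ll` are looking for.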

@llvmbot (Member) commented Nov 19, 2024

@llvm/pr-subscribers-backend-directx

+//
+// \file Pass for replacing pointers to DXIL resources with load and store
+// operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
+
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class DXILResourceAccess : public PassInfoMixin<DXILResourceAccess> {
+public:
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILRESOURCEACCESS_H
diff --git a/llvm/lib/Target/DirectX/DirectX.h b/llvm/lib/Target/DirectX/DirectX.h
index 3454f16ecd5955..add23587de7d58 100644
--- a/llvm/lib/Target/DirectX/DirectX.h
+++ b/llvm/lib/Target/DirectX/DirectX.h
@@ -12,6 +12,7 @@
 #define LLVM_LIB_TARGET_DIRECTX_DIRECTX_H
 
 namespace llvm {
+class FunctionPass;
 class ModulePass;
 class PassRegistry;
 class raw_ostream;
@@ -52,6 +53,12 @@ void initializeDXILOpLoweringLegacyPass(PassRegistry &);
 /// Pass to lowering LLVM intrinsic call to DXIL op function call.
 ModulePass *createDXILOpLoweringLegacyPass();
 
+/// Initializer for DXILResourceAccess
+void initializeDXILResourceAccessLegacyPass(PassRegistry &);
+
+/// Pass to update resource accesses to use load/store directly.
+FunctionPass *createDXILResourceAccessLegacyPass();
+
 /// Initializer for DXILTranslateMetadata.
 void initializeDXILTranslateMetadataLegacyPass(PassRegistry &);
 
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index a0f864ed39375f..87591b104ce52c 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -32,3 +32,9 @@ MODULE_PASS("dxil-translate-metadata", DXILTranslateMetadata())
 // TODO: rename to print<foo> after NPM switch
 MODULE_PASS("print-dx-shader-flags", dxil::ShaderFlagsAnalysisPrinter(dbgs()))
 #undef MODULE_PASS
+
+#ifndef FUNCTION_PASS
+#define FUNCTION_PASS(NAME, CREATE_PASS)
+#endif
+FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
+#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 655427a3e80209..9dade16ffe2732 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -17,6 +17,7 @@
 #include "DXILIntrinsicExpansion.h"
 #include "DXILOpLowering.h"
 #include "DXILPrettyPrinter.h"
+#include "DXILResourceAccess.h"
 #include "DXILResourceAnalysis.h"
 #include "DXILShaderFlags.h"
 #include "DXILTranslateMetadata.h"
@@ -56,6 +57,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
   initializeWriteDXILPassPass(*PR);
   initializeDXContainerGlobalsPass(*PR);
   initializeDXILOpLoweringLegacyPass(*PR);
+  initializeDXILResourceAccessLegacyPass(*PR);
   initializeDXILTranslateMetadataLegacyPass(*PR);
   initializeDXILResourceMDWrapperPass(*PR);
   initializeShaderFlagsAnalysisWrapperPass(*PR);
@@ -91,9 +93,10 @@ class DirectXPassConfig : public TargetPassConfig {
   void addCodeGenPrepare() override {
     addPass(createDXILIntrinsicExpansionLegacyPass());
     addPass(createDXILDataScalarizationLegacyPass());
+    addPass(createDXILFlattenArraysLegacyPass());
+    addPass(createDXILResourceAccessLegacyPass());
     ScalarizerPassOptions DxilScalarOptions;
     DxilScalarOptions.ScalarizeLoadStore = true;
-    addPass(createDXILFlattenArraysLegacyPass());
     addPass(createScalarizerPass(DxilScalarOptions));
     addPass(createDXILOpLoweringLegacyPass());
     addPass(createDXILFinalizeLinkageLegacyPass());
diff --git a/llvm/lib/Transforms/Scalar/Scalarizer.cpp b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
index 03d069c9fcb36d..9341bc8bc02de6 100644
--- a/llvm/lib/Transforms/Scalar/Scalarizer.cpp
+++ b/llvm/lib/Transforms/Scalar/Scalarizer.cpp
@@ -18,6 +18,7 @@
 #include "llvm/ADT/PostOrderIterator.h"
 #include "llvm/ADT/SmallVector.h"
 #include "llvm/ADT/Twine.h"
+#include "llvm/Analysis/DXILResource.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/Analysis/VectorUtils.h"
 #include "llvm/IR/Argument.h"
@@ -351,6 +352,7 @@ void ScalarizerLegacyPass::getAnalysisUsage(AnalysisUsage &AU) const {
   AU.addRequired<DominatorTreeWrapperPass>();
   AU.addRequired<TargetTransformInfoWrapperPass>();
   AU.addPreserved<DominatorTreeWrapperPass>();
+  AU.addPreserved<DXILResourceWrapperPass>();
 }
 
 char ScalarizerLegacyPass::ID = 0;
@@ -1348,5 +1350,6 @@ PreservedAnalyses ScalarizerPass::run(Function &F, FunctionAnalysisManager &AM)
   bool Changed = Impl.visit(F);
   PreservedAnalyses PA;
   PA.preserve<DominatorTreeAnalysis>();
+  PA.preserve<DXILResourceAnalysis>();
   return Changed ? PA : PreservedAnalyses::all();
 }
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
new file mode 100644
index 00000000000000..2c17ec674632ba
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/load_typedbuffer.ll
@@ -0,0 +1,35 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+declare void @use_float4(<4 x float>)
+declare void @use_float(float)
+
+; CHECK-LABEL: define void @load_float4
+define void @load_float4(i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  %vec_data = load <4 x float>, ptr %ptr
+  call void @use_float4(<4 x float> %vec_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 1
+  %y_ptr = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 1
+  %y_data = load float, ptr %y_ptr
+  call void @use_float(float %y_data)
+
+  ; CHECK: %[[VALUE:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: extractelement <4 x float> %[[VALUE]], i32 %elemindex
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  %dyndata = load float, ptr %dynamic
+  call void @use_float(float %dyndata)
+
+  ret void
+}
diff --git a/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
new file mode 100644
index 00000000000000..dd63acc3c0e96c
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/ResourceAccess/store_typedbuffer.ll
@@ -0,0 +1,103 @@
+; RUN: opt -S -dxil-resource-access %s | FileCheck %s
+
+target triple = "dxil-pc-shadermodel6.6-compute"
+
+; CHECK-LABEL: define void @store_float4
+define void @store_float4(<4 x float> %data, i32 %index, i32 %elemindex) {
+  %buffer = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0)
+      @llvm.dx.handle.fromBinding.tdx.TypedBuffer_v4f32_1_0_0(
+          i32 0, i32 0, i32 1, i32 0, i1 false)
+
+  ; CHECK-NOT: @llvm.dx.resource.getpointer
+  %ptr = call ptr @llvm.dx.resource.getpointer(
+      target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+
+  ; Store the whole value
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %data)
+  store <4 x float> %data, ptr %ptr
+
+  ; Store just the .x component
+  %scalar = extractelement <4 x float> %data, i32 0
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 0
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  store float %scalar, ptr %ptr
+
+  ; Store just the .y component
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 1
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %y_ptr = getelementptr inbounds i8, ptr %ptr, i32 4
+  store float %scalar, ptr %y_ptr
+
+  ; Store to one of the elements dynamically
+  ; CHECK: %[[LOAD:.*]] = call <4 x float> @llvm.dx.typedBufferLoad.v4f32.tdx.TypedBuffer_v4f32_1_0_0t(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index)
+  ; CHECK: %[[INSERT:.*]] = insertelement <4 x float> %[[LOAD]], float %scalar, i32 %elemindex
+  ; CHECK: call void @llvm.dx.typedBufferStore.tdx.TypedBuffer_v4f32_1_0_0t.v4f32(target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %buffer, i32 %index, <4 x float> %[[INSERT]])
+  %dynamic = getelementptr inbounds <4 x float>, ptr %ptr, i32 0, i32 %elemindex
+  store float %scalar, ptr %dynamic
+
+  ret void
+}
+
+; CHECK-LABEL: define void @store_half4
+define void @store_half4(<4 x half> %data, i32 %index) {
+  %buffer = call target("dx.TypedBuffer", <4 x half>, 1, 0, 0)
+      @llvm....
[truncated]
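
The CHECK lines in the tests above suggest the shape of the lowering: a scalar load through the resource pointer becomes a whole-element `typedBufferLoad` followed by `extractelement`, and a scalar store becomes load, `insertelement`, then a whole-element `typedBufferStore`. The sketch below models that read-modify-write pattern in plain C++ with no LLVM dependency. `TypedBufferModel`, `loadComponent`, `storeComponent`, and `componentFromByteOffset` are hypothetical names invented for illustration; they are not part of the patch, and the byte-offset helper only assumes 32-bit float components as in the `<4 x float>` tests.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a typed buffer of <4 x float> elements. The real
// pass emits llvm.dx.typedBufferLoad / llvm.dx.typedBufferStore calls; this
// model only mimics their whole-element semantics.
using Float4 = std::array<float, 4>;

struct TypedBufferModel {
  std::vector<Float4> Elements;

  // Models typedBufferLoad: always reads a whole element.
  Float4 load(std::size_t Index) const { return Elements[Index]; }

  // Models typedBufferStore: always writes a whole element.
  void store(std::size_t Index, const Float4 &Value) {
    Elements[Index] = Value;
  }
};

// Scalar load: load the whole element, then extract one component
// (the analogue of typedBufferLoad + extractelement).
float loadComponent(const TypedBufferModel &Buf, std::size_t Index,
                    std::size_t Component) {
  assert(Component < 4 && "component index out of range");
  Float4 V = Buf.load(Index);
  return V[Component];
}

// Scalar store: load the whole element, overwrite one component, and store
// the whole element back (load + insertelement + typedBufferStore).
void storeComponent(TypedBufferModel &Buf, std::size_t Index,
                    std::size_t Component, float Scalar) {
  assert(Component < 4 && "component index out of range");
  Float4 V = Buf.load(Index);
  V[Component] = Scalar;
  Buf.store(Index, V);
}

// Converts a byte offset into the element (as in an i8 getelementptr with
// offset 4) to a component index, assuming float-sized components.
std::size_t componentFromByteOffset(std::size_t ByteOffset) {
  return ByteOffset / sizeof(float);
}
```

In this model a store to the `.y` component at byte offset 4 touches only component 1 of the element, which mirrors the `insertelement ... i32 1` expectation in `store_typedbuffer.ll`.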

@github-actions

github-actions bot commented Nov 19, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Member

@hekota hekota left a comment

Couple of questions about the tests, otherwise LGTM!

Contributor

@damyanp damyanp left a comment

lgtm

bogner added a commit to bogner/llvm-project that referenced this pull request Nov 27, 2024
Including this removes the PR's dependency on llvm#116726
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 765ddaf to f12a6bf Compare December 12, 2024 22:18
@bogner bogner changed the base branch from main to users/bogner/119773 December 12, 2024 22:29
@bogner bogner force-pushed the 2024-11-18-resource-access branch from f12a6bf to 958f7ec Compare December 12, 2024 22:49
@bogner bogner force-pushed the users/bogner/119773 branch 3 times, most recently from f8291d2 to 4f018d3 Compare December 16, 2024 23:15
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 0ee6dbf to 2171296 Compare December 16, 2024 23:18
@bogner bogner changed the base branch from users/bogner/119773 to main December 18, 2024 16:03
@bogner bogner force-pushed the 2024-11-18-resource-access branch from 2171296 to 4666082 Compare December 18, 2024 16:06
@bogner bogner merged commit 0fca76d into llvm:main Dec 18, 2024
6 of 7 checks passed

Labels

backend:DirectX llvm:analysis Includes value tracking, cost tables and constant folding llvm:ir llvm:transforms

Projects

Archived in project

Development

Successfully merging this pull request may close these issues.

[DirectX] Replace resource accesses in the device address space to typedbuffer load and store

7 participants