Skip to content

Commit ca7c111

Browse files
committed
[DirectX] Implement the ForwardHandleAccesses pass
This pass attempts to forward resource handle creation to accesses of the handle global. This avoids dependence on optimizations like CSE and GlobalOpt for correctness of DXIL. Fixes #134574.
1 parent 1da856a commit ca7c111

File tree

13 files changed

+348
-1
lines changed

13 files changed

+348
-1
lines changed

llvm/include/llvm/Analysis/DXILResource.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,24 @@ class SamplerExtType : public TargetExtType {
196196
}
197197
};
198198

199+
class AnyResourceExtType : public TargetExtType {
200+
public:
201+
AnyResourceExtType() = delete;
202+
AnyResourceExtType(const AnyResourceExtType &) = delete;
203+
AnyResourceExtType &operator=(const AnyResourceExtType &) = delete;
204+
205+
static bool classof(const TargetExtType *T) {
206+
return isa<RawBufferExtType>(T) || isa<TypedBufferExtType>(T) ||
207+
isa<TextureExtType>(T) || isa<MSTextureExtType>(T) ||
208+
isa<FeedbackTextureExtType>(T) || isa<CBufferExtType>(T) ||
209+
isa<SamplerExtType>(T);
210+
}
211+
212+
static bool classof(const Type *T) {
213+
return isa<TargetExtType>(T) && classof(cast<TargetExtType>(T));
214+
}
215+
};
216+
199217
/// The dx.Layout target extension type
200218
///
201219
/// `target("dx.Layout", <Type>, <size>, [offsets...])`

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ add_llvm_target(DirectXCodeGen
2323
DXILCBufferAccess.cpp
2424
DXILDataScalarization.cpp
2525
DXILFinalizeLinkage.cpp
26+
DXILForwardHandleAccesses.cpp
2627
DXILFlattenArrays.cpp
2728
DXILIntrinsicExpansion.cpp
2829
DXILOpBuilder.cpp
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
//===- DXILForwardHandleAccesses.cpp - Cleanup Handles --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILForwardHandleAccesses.h"
10+
#include "DXILShaderFlags.h"
11+
#include "DirectX.h"
12+
#include "llvm/Analysis/DXILResource.h"
13+
#include "llvm/Analysis/Loads.h"
14+
#include "llvm/IR/DiagnosticInfo.h"
15+
#include "llvm/IR/Dominators.h"
16+
#include "llvm/IR/IntrinsicInst.h"
17+
#include "llvm/IR/Intrinsics.h"
18+
#include "llvm/IR/IntrinsicsDirectX.h"
19+
#include "llvm/IR/Module.h"
20+
#include "llvm/InitializePasses.h"
21+
#include "llvm/Pass.h"
22+
#include "llvm/Transforms/Utils/Local.h"
23+
24+
#define DEBUG_TYPE "dxil-forward-handle-accesses"
25+
26+
using namespace llvm;
27+
28+
/// Emit a diagnostic reporting that the handle created by \p NewII is stored
/// to a global that already has the handle created by \p PrevII recorded for
/// it.
static void diagnoseAmbiguousHandle(IntrinsicInst *NewII,
                                    IntrinsicInst *PrevII) {
  Function *F = NewII->getFunction();
  LLVMContext &Context = F->getParent()->getContext();
  // Attach the offending instruction to the diagnostic, as the other
  // diagnostics in this file do. The message text itself is unchanged.
  Context.diagnose(DiagnosticInfoGeneric(
      NewII, Twine("Handle at \"") + NewII->getName() +
                 "\" overwrites handle at \"" + PrevII->getName() + "\""));
}
36+
37+
/// Emit a diagnostic for a resource-typed load \p LI whose pointer operand
/// could not be matched to a recorded handle global.
static void diagnoseHandleNotFound(LoadInst *LI) {
  // Materialize the message first so the diagnostic call stays on one line.
  const std::string Message =
      (Twine("Load of \"") + LI->getPointerOperand()->getName() +
       "\" is not a global resource handle")
          .str();
  LLVMContext &Ctx = LI->getFunction()->getParent()->getContext();
  Ctx.diagnose(DiagnosticInfoGeneric(LI, Message));
}
44+
45+
/// Emit a diagnostic for a load \p LI that is not dominated by the handle
/// creation \p Handle it would otherwise be replaced with.
static void diagnoseUndominatedLoad(LoadInst *LI, IntrinsicInst *Handle) {
  // Materialize the message first so the diagnostic call stays on one line.
  const std::string Message =
      (Twine("Load at \"") + LI->getName() +
       "\" is not dominated by handle creation at \"" + Handle->getName() +
       "\"")
          .str();
  LLVMContext &Ctx = LI->getFunction()->getParent()->getContext();
  Ctx.diagnose(DiagnosticInfoGeneric(LI, Message));
}
53+
54+
static void
55+
processHandle(IntrinsicInst *II,
56+
DenseMap<GlobalVariable *, IntrinsicInst *> &HandleMap) {
57+
for (User *U : II->users())
58+
if (auto *SI = dyn_cast<StoreInst>(U))
59+
if (auto *GV = dyn_cast<GlobalVariable>(SI->getPointerOperand())) {
60+
auto Entry = HandleMap.try_emplace(GV, II);
61+
if (Entry.second)
62+
LLVM_DEBUG(dbgs() << "Added " << GV->getName() << " to handle map\n");
63+
else
64+
diagnoseAmbiguousHandle(II, Entry.first->second);
65+
}
66+
}
67+
68+
/// Replace loads of resource handles from globals in \p F with the
/// handle-creation intrinsic that was stored to the same global.
///
/// The function is scanned once to (a) map each global to the
/// dx.resource.handlefrombinding call stored to it and (b) collect every load
/// whose result type is a resource type. Each collected load is then either
/// replaced by the mapped handle and erased, or diagnosed when no handle is
/// known for it or when the handle creation does not dominate the load
/// according to \p DT.
///
/// \returns true only if at least one load was actually replaced and erased.
static bool forwardHandleAccesses(Function &F, DominatorTree &DT) {
  bool Changed = false;

  DenseMap<GlobalVariable *, IntrinsicInst *> HandleMap;
  SmallVector<LoadInst *> LoadsToProcess;
  for (BasicBlock &BB : F)
    for (Instruction &Inst : BB) {
      if (auto *II = dyn_cast<IntrinsicInst>(&Inst)) {
        switch (II->getIntrinsicID()) {
        case Intrinsic::dx_resource_handlefrombinding:
          processHandle(II, HandleMap);
          break;
        default:
          break;
        }
      } else if (auto *LI = dyn_cast<LoadInst>(&Inst)) {
        if (isa<dxil::AnyResourceExtType>(LI->getType()))
          LoadsToProcess.push_back(LI);
      }
    }

  for (LoadInst *LI : LoadsToProcess) {
    Value *V = LI->getPointerOperand();
    auto *GV = dyn_cast<GlobalVariable>(V);

    // If we didn't find the global, we may need to walk through a level of
    // indirection. This generally happens at -O0.
    if (!GV)
      if (auto *NestedLI = dyn_cast<LoadInst>(V)) {
        BasicBlock::iterator BBI(NestedLI);
        Value *Loaded = FindAvailableLoadedValue(
            NestedLI, NestedLI->getParent(), BBI, 0, nullptr, nullptr);
        GV = dyn_cast_or_null<GlobalVariable>(Loaded);
      }

    // GV may still be null here; a null key is never in the map, so that case
    // falls into the "not found" diagnostic below.
    auto It = HandleMap.find(GV);
    if (It == HandleMap.end()) {
      diagnoseHandleNotFound(LI);
      continue;
    }

    if (!DT.dominates(It->second, LI)) {
      diagnoseUndominatedLoad(LI, It->second);
      continue;
    }

    LLVM_DEBUG(dbgs() << "Replacing uses of " << GV->getName() << " at "
                      << LI->getName() << " with " << It->second->getName()
                      << "\n");
    LI->replaceAllUsesWith(It->second);
    LI->eraseFromParent();
    // Only report a change once the IR has actually been modified; the
    // diagnostic paths above leave the function untouched.
    Changed = true;
  }

  return Changed;
}
122+
123+
/// New pass manager entry point: runs the forwarding on \p F using the
/// dominator tree obtained from \p AM. Reports all analyses preserved when
/// nothing changed and none preserved otherwise.
PreservedAnalyses DXILForwardHandleAccesses::run(Function &F,
                                                 FunctionAnalysisManager &AM) {
  DominatorTree &DomTree = AM.getResult<DominatorTreeAnalysis>(F);
  if (forwardHandleAccesses(F, DomTree))
    return PreservedAnalyses();
  return PreservedAnalyses::all();
}
134+
135+
namespace {
136+
class DXILForwardHandleAccessesLegacy : public FunctionPass {
137+
public:
138+
bool runOnFunction(Function &F) override {
139+
DominatorTree *DT = &getAnalysis<DominatorTreeWrapperPass>().getDomTree();
140+
return forwardHandleAccesses(F, *DT);
141+
}
142+
StringRef getPassName() const override {
143+
return "DXIL Forward Handle Accesses";
144+
}
145+
146+
void getAnalysisUsage(AnalysisUsage &AU) const override {
147+
AU.addRequired<DominatorTreeWrapperPass>();
148+
}
149+
150+
DXILForwardHandleAccessesLegacy() : FunctionPass(ID) {}
151+
152+
static char ID; // Pass identification.
153+
};
154+
char DXILForwardHandleAccessesLegacy::ID = 0;
155+
} // end anonymous namespace
156+
157+
// Legacy pass initialization for DXILForwardHandleAccessesLegacy, keyed by
// DEBUG_TYPE, with DominatorTreeWrapperPass listed as a dependency.
// NOTE(review): presumably this is what defines
// initializeDXILForwardHandleAccessesLegacyPass - confirm against the macro.
INITIALIZE_PASS_BEGIN(DXILForwardHandleAccessesLegacy, DEBUG_TYPE,
158+
"DXIL Forward Handle Accesses", false, false)
159+
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
160+
INITIALIZE_PASS_END(DXILForwardHandleAccessesLegacy, DEBUG_TYPE,
161+
"DXIL Forward Handle Accesses", false, false)
162+
163+
/// Factory for the legacy pass manager version of the pass: returns a newly
/// allocated DXILForwardHandleAccessesLegacy.
FunctionPass *llvm::createDXILForwardHandleAccessesLegacyPass() {
164+
return new DXILForwardHandleAccessesLegacy();
165+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===- DXILForwardHandleAccesses.h - Cleanup Handles ------------*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// \file Forward resource handle creation to loads of handle globals,
// eliminating the redundant loads.
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILFORWARDHANDLEACCESSES_H
#define LLVM_LIB_TARGET_DIRECTX_DXILFORWARDHANDLEACCESSES_H

#include "llvm/IR/PassManager.h"

namespace llvm {

/// New pass manager function pass. Its work is done in run(), which is
/// defined out of line.
class DXILForwardHandleAccesses
    : public PassInfoMixin<DXILForwardHandleAccesses> {
public:
  /// Run the pass over \p F, using \p AM to obtain any analyses it needs.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILFORWARDHANDLEACCESSES_H

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
5353
/// Pass to flatten arrays into a one dimensional DXIL legal form
5454
ModulePass *createDXILFlattenArraysLegacyPass();
5555

56+
/// Initializer for DXIL Forward Handle Accesses Pass
57+
void initializeDXILForwardHandleAccessesLegacyPass(PassRegistry &);
58+
59+
/// Pass to forward resource handle creation to loads of handle globals,
/// eliminating the redundant loads.
60+
FunctionPass *createDXILForwardHandleAccessesLegacyPass();
61+
5662
/// Initializer DXIL legalizationPass
5763
void initializeDXILLegalizeLegacyPass(PassRegistry &);
5864

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
3838
#ifndef FUNCTION_PASS
3939
#define FUNCTION_PASS(NAME, CREATE_PASS)
4040
#endif
41+
FUNCTION_PASS("dxil-forward-handle-accesses", DXILForwardHandleAccesses())
4142
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
4243
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
4344
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "DXILCBufferAccess.h"
1616
#include "DXILDataScalarization.h"
1717
#include "DXILFlattenArrays.h"
18+
#include "DXILForwardHandleAccesses.h"
1819
#include "DXILIntrinsicExpansion.h"
1920
#include "DXILLegalizePass.h"
2021
#include "DXILOpLowering.h"
@@ -66,6 +67,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
6667
initializeRootSignatureAnalysisWrapperPass(*PR);
6768
initializeDXILFinalizeLinkageLegacyPass(*PR);
6869
initializeDXILPrettyPrinterLegacyPass(*PR);
70+
initializeDXILForwardHandleAccessesLegacyPass(*PR);
6971
initializeDXILCBufferAccessLegacyPass(*PR);
7072
}
7173

@@ -105,6 +107,7 @@ class DirectXPassConfig : public TargetPassConfig {
105107
ScalarizerPassOptions DxilScalarOptions;
106108
DxilScalarOptions.ScalarizeLoadStore = true;
107109
addPass(createScalarizerPass(DxilScalarOptions));
110+
addPass(createDXILForwardHandleAccessesLegacyPass());
108111
addPass(createDXILLegalizeLegacyPass());
109112
addPass(createDXILTranslateMetadataLegacyPass());
110113
addPass(createDXILOpLoweringLegacyPass());
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
; RUN: not opt -S -dxil-forward-handle-accesses -mtriple=dxil--shadermodel6.3-library %s 2>&1 | FileCheck %s
2+
3+
; CHECK: error: Load of "buf" is not a global resource handle
4+
5+
%"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", <4 x float>, 1, 0) }
6+
@Buf = internal global %"class.hlsl::RWStructuredBuffer" poison, align 4
7+
8+
; Negative test: the handle is stored to, and loaded back from, the local
; alloca %buf rather than a global, so the resource-typed load %b is expected
; to produce the error checked above.
define float @f() {
9+
entry:
10+
%buf = alloca target("dx.RawBuffer", <4 x float>, 1, 0), align 4
11+
%h = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
12+
store target("dx.RawBuffer", <4 x float>, 1, 0) %h, ptr %buf, align 4
13+
14+
%b = load target("dx.RawBuffer", <4 x float>, 1, 0), ptr %buf, align 4
15+
%l = call { <4 x float>, i1 } @llvm.dx.resource.load.rawbuffer(target("dx.RawBuffer", <4 x float>, 1, 0) %b, i32 0, i32 0)
16+
%x = extractvalue { <4 x float>, i1 } %l, 0
17+
%v = extractelement <4 x float> %x, i32 0
18+
19+
ret float %v
20+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
; RUN: not opt -S -dxil-forward-handle-accesses -mtriple=dxil--shadermodel6.3-library %s 2>&1 | FileCheck %s
2+
3+
; CHECK: error: Handle at "h2" overwrites handle at "h1"
4+
5+
%"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", <4 x float>, 1, 0) }
6+
@Buf = internal global %"class.hlsl::RWStructuredBuffer" poison, align 4
7+
8+
; Negative test: two distinct handle creations (%h1 and %h2) are both stored
; to @Buf, which is expected to produce the "overwrites handle" error checked
; above.
define float @f() {
9+
entry:
10+
%h1 = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding(i32 0, i32 0, i32 1, i32 0, i1 false)
11+
store target("dx.RawBuffer", <4 x float>, 1, 0) %h1, ptr @Buf, align 4
12+
%h2 = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding(i32 0, i32 1, i32 1, i32 0, i1 false)
13+
store target("dx.RawBuffer", <4 x float>, 1, 0) %h2, ptr @Buf, align 4
14+
15+
%b = load target("dx.RawBuffer", <4 x float>, 1, 0), ptr @Buf, align 4
16+
%l = call { <4 x float>, i1 } @llvm.dx.resource.load.rawbuffer(target("dx.RawBuffer", <4 x float>, 1, 0) %b, i32 0, i32 0)
17+
%x = extractvalue { <4 x float>, i1 } %l, 0
18+
%v = extractelement <4 x float> %x, i32 0
19+
20+
ret float %v
21+
}
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
; RUN: opt -S -dxil-forward-handle-accesses -mtriple=dxil--shadermodel6.3-library %s | FileCheck %s
2+
3+
%"class.hlsl::RWStructuredBuffer" = type { target("dx.RawBuffer", <4 x float>, 1, 0) }
4+
5+
@_ZL2In = internal global %"class.hlsl::RWStructuredBuffer" poison, align 4
6+
@_ZL3Out = internal global %"class.hlsl::RWStructuredBuffer" poison, align 4
7+
8+
; Positive test: handles are stored to @_ZL2In and @_ZL3Out, and the resource
; loads later go through pointers that were themselves stored to and reloaded
; from allocas. The CHECK lines expect the resource-typed loads to be gone and
; the rawbuffer load/store calls to use the handlefrombinding results directly.
define void @main() #1 {
9+
entry:
10+
%this.addr.i.i.i = alloca ptr, align 4
11+
%this.addr.i.i = alloca ptr, align 4
12+
%this.addr.i1 = alloca ptr, align 4
13+
%Index.addr.i2 = alloca i32, align 4
14+
%this.addr.i = alloca ptr, align 4
15+
%Index.addr.i = alloca i32, align 4
16+
; CHECK: [[IN:%.*]] = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_v4f32_1_0t(i32 0, i32 0, i32 1, i32 0, i1 false)
17+
%_ZL2In_h.i.i = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_v4f32_1_0t(i32 0, i32 0, i32 1, i32 0, i1 false)
18+
store target("dx.RawBuffer", <4 x float>, 1, 0) %_ZL2In_h.i.i, ptr @_ZL2In, align 4
19+
store ptr @_ZL2In, ptr %this.addr.i.i, align 4
20+
%this1.i.i = load ptr, ptr %this.addr.i.i, align 4
21+
; CHECK: [[OUT:%.*]] = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_v4f32_1_0t(i32 0, i32 1, i32 1, i32 0, i1 false)
22+
%_ZL3Out_h.i.i = call target("dx.RawBuffer", <4 x float>, 1, 0) @llvm.dx.resource.handlefrombinding.tdx.RawBuffer_v4f32_1_0t(i32 0, i32 1, i32 1, i32 0, i1 false)
23+
store target("dx.RawBuffer", <4 x float>, 1, 0) %_ZL3Out_h.i.i, ptr @_ZL3Out, align 4
24+
store ptr @_ZL3Out, ptr %this.addr.i.i.i, align 4
25+
%this1.i.i.i = load ptr, ptr %this.addr.i.i.i, align 4
26+
store ptr @_ZL2In, ptr %this.addr.i1, align 4
27+
store i32 0, ptr %Index.addr.i2, align 4
28+
%this1.i3 = load ptr, ptr %this.addr.i1, align 4
29+
; CHECK-NOT: load target("dx.RawBuffer", <4 x float>, 1, 0)
30+
%0 = load target("dx.RawBuffer", <4 x float>, 1, 0), ptr %this1.i3, align 4
31+
%1 = load i32, ptr %Index.addr.i2, align 4
32+
; CHECK: call { <4 x float>, i1 } @llvm.dx.resource.load.rawbuffer.v4f32.tdx.RawBuffer_v4f32_1_0t(target("dx.RawBuffer", <4 x float>, 1, 0) [[IN]],
33+
%2 = call { <4 x float>, i1 } @llvm.dx.resource.load.rawbuffer.v4f32.tdx.RawBuffer_v4f32_1_0t(target("dx.RawBuffer", <4 x float>, 1, 0) %0, i32 %1, i32 0)
34+
%3 = extractvalue { <4 x float>, i1 } %2, 0
35+
store ptr @_ZL3Out, ptr %this.addr.i, align 4
36+
store i32 0, ptr %Index.addr.i, align 4
37+
%this1.i = load ptr, ptr %this.addr.i, align 4
38+
; CHECK-NOT: load target("dx.RawBuffer", <4 x float>, 1, 0)
39+
%4 = load target("dx.RawBuffer", <4 x float>, 1, 0), ptr %this1.i, align 4
40+
%5 = load i32, ptr %Index.addr.i, align 4
41+
; CHECK: call void @llvm.dx.resource.store.rawbuffer.tdx.RawBuffer_v4f32_1_0t.v4f32(target("dx.RawBuffer", <4 x float>, 1, 0) [[OUT]],
42+
call void @llvm.dx.resource.store.rawbuffer.tdx.RawBuffer_v4f32_1_0t.v4f32(target("dx.RawBuffer", <4 x float>, 1, 0) %4, i32 %5, i32 0, <4 x float> %3)
43+
ret void
44+
}

0 commit comments

Comments
 (0)