Skip to content

Commit d68261f

Browse files
committed
mvp
1 parent 4637bf0 commit d68261f

File tree

6 files changed

+319
-0
lines changed

6 files changed

+319
-0
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ add_llvm_target(DirectXCodeGen
3131
DXILPostOptimizationValidation.cpp
3232
DXILPrepare.cpp
3333
DXILPrettyPrinter.cpp
34+
DXILRefineTypes.cpp
3435
DXILResourceAccess.cpp
3536
DXILResourceImplicitBinding.cpp
3637
DXILShaderFlags.cpp
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
//===- DXILRefineTypes.cpp ------------------------------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "DXILRefineTypes.h"
10+
#include "llvm/IR/DerivedTypes.h"
11+
#include "llvm/IR/GlobalVariable.h"
12+
#include "llvm/IR/IRBuilder.h"
13+
#include "llvm/IR/Instructions.h"
14+
#include "llvm/IR/Intrinsics.h"
15+
#include "llvm/IR/Module.h"
16+
17+
using namespace llvm;
18+
19+
#define DEBUG_TYPE "dxil-refine-types"
20+
21+
/// Try to recover a more descriptive type for the memory \p Operand points to.
///
/// Two kinds of pointer producers are recognised:
///  - a direct call whose callee name starts with
///    "llvm.dx.resource.getpointer": the first type parameter of the target
///    extension type of the call's first argument is returned;
///  - an alloca: its allocated type is returned.
///
/// \returns the inferred type, or nullptr if nothing could be inferred.
static Type *inferType(Value *Operand) {
  if (auto *CI = dyn_cast<CallInst>(Operand)) {
    // getCalledFunction() returns null for indirect calls, which can never be
    // the intrinsic we are looking for; do not dereference it blindly.
    const Function *Callee = CI->getCalledFunction();
    if (Callee && Callee->getName().starts_with("llvm.dx.resource.getpointer"))
      if (auto *ExtType =
              dyn_cast<TargetExtType>(CI->getArgOperand(0)->getType()))
        // getTypeParameter(0) asserts when the target type carries no type
        // parameters, so only query it when one is present.
        if (ExtType->getNumTypeParameters() > 0)
          return ExtType->getTypeParameter(0);
  }

  if (auto *AI = dyn_cast<AllocaInst>(Operand))
    return AI->getAllocatedType();

  // TODO: Extend to other useful/applicable cases
  return nullptr;
}
34+
35+
// Attempt to merge two inferred types.
36+
//
37+
// Returns nullptr, if an inferred can't be concluded
38+
static Type *mergeInferredTypes(Type *A, Type *B) {
39+
if (!A)
40+
return B;
41+
42+
if (!B)
43+
return A;
44+
45+
if (A == B)
46+
return A;
47+
48+
// Otherwise, neither was inferred, or inferred differently
49+
return nullptr;
50+
}
51+
52+
/// Rewrite integer-typed load/store copies in \p F so that they use the type
/// inferred from their pointer operands.
///
/// The pattern matched is the one generated by SimplifyAnyMemTransfer:
///   %temp = load Type, ptr %src, ...
///   store Type %temp, ptr %dest, ...
/// where
///   - the store is the only user of %temp,
///   - Type is either i32 or i64,
///   - neither access is atomic.
///
/// Only i32 and i64 are considered because these can incidentally promote
/// 16/32-bit types to 32/64-bit arithmetic.
///
/// \returns true if any instruction in \p F was replaced.
bool DXILRefineTypesPass::runImpl(Function &F) {
  const DataLayout &DL = F.getParent()->getDataLayout();

  // Collect the candidates first and rewrite afterwards, so that instructions
  // are not erased while the function is being iterated.
  SmallVector<std::pair<LoadInst *, StoreInst *>, 4> ToVisit;
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *LI = dyn_cast<LoadInst>(&I);
      if (!LI || !LI->hasOneUse())
        continue;

      auto *SI = dyn_cast<StoreInst>(LI->user_back());
      if (!SI || LI->getAccessType() != SI->getAccessType())
        continue;

      // Atomic loads/stores must keep an integer, pointer or floating-point
      // type, so they cannot be retyped to an aggregate.
      if (LI->isAtomic() || SI->isAtomic())
        continue;

      Type *AccessTy = LI->getAccessType();
      if (AccessTy->isIntegerTy(32) || AccessTy->isIntegerTy(64))
        ToVisit.push_back({LI, SI});
    }
  }

  bool Modified = false;
  for (auto [LI, SI] : ToVisit) {
    Type *LoadTy = inferType(LI->getPointerOperand());
    Type *StoreTy = inferType(SI->getPointerOperand());

    Type *const InferredTy = mergeInferredTypes(LoadTy, StoreTy);
    if (!InferredTy || InferredTy == LI->getType())
      continue; // Nothing to be done. Skip.

    // The retyped pair must move exactly the same number of bytes as the
    // integer pair did; otherwise the copy would read or write a different
    // amount of memory than the original.
    if (!InferredTy->isSized() ||
        DL.getTypeStoreSize(InferredTy) != DL.getTypeStoreSize(LI->getType()))
      continue;

    // Create the typed load where the original load was. Creating it at the
    // store would move the read past any intervening writes to the source.
    IRBuilder<> LoadBuilder(LI);
    LoadInst *TypedLoad =
        LoadBuilder.CreateLoad(InferredTy, LI->getPointerOperand());

    TypedLoad->setAlignment(LI->getAlign());
    TypedLoad->setVolatile(LI->isVolatile());
    TypedLoad->setAAMetadata(LI->getAAMetadata());
    TypedLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access);
    TypedLoad->copyMetadata(*LI, LLVMContext::MD_access_group);

    IRBuilder<> StoreBuilder(SI);
    StoreInst *TypedStore = StoreBuilder.CreateStore(
        TypedLoad, SI->getPointerOperand(), SI->isVolatile());

    TypedStore->setAlignment(SI->getAlign());
    TypedStore->setAAMetadata(SI->getAAMetadata());
    TypedStore->copyMetadata(*SI, LLVMContext::MD_mem_parallel_loop_access);
    TypedStore->copyMetadata(*SI, LLVMContext::MD_access_group);
    TypedStore->copyMetadata(*SI, LLVMContext::MD_DIAssignID);

    // Erase the store first: it is the only user of the old load.
    SI->eraseFromParent();
    LI->eraseFromParent();

    Modified = true;
  }

  return Modified;
}
114+
115+
/// New pass manager entry point: runs the rewrite on \p F and reports which
/// analyses remain valid afterwards.
PreservedAnalyses DXILRefineTypesPass::run(Function &F,
                                           FunctionAnalysisManager &AM) {
  if (!runImpl(F))
    return PreservedAnalyses::all();

  // Only loads and stores are replaced; no basic block or terminator is
  // added, removed or rewired, so analyses that depend solely on the CFG
  // remain valid.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
//===- DXILRefineTypes.h - Infer additional type information ----*- C++ -*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
//
9+
// This pass infers increased type information fidelity of memory operations and
10+
// replaces their uses with it. Primarily load/store.
11+
//
12+
// This is used to prevent the propagation of introduced problematic types. For
13+
// instance: the InstCombine pass can promote aggregates of 16/32-bit types to
14+
// be i32/64 loads and stores.
15+
//
16+
//===----------------------------------------------------------------------===//
17+
18+
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
#define LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

class Function;

/// Function pass that replaces integer-typed load/store copies with loads and
/// stores of the type inferred from their pointer operands.
class DXILRefineTypesPass : public PassInfoMixin<DXILRefineTypesPass> {
private:
  /// Performs the rewrite on \p F. Returns true if \p F was modified.
  bool runImpl(Function &F);

public:
  DXILRefineTypesPass() = default;

  /// New pass manager entry point.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
4242
#define FUNCTION_PASS(NAME, CREATE_PASS)
4343
#endif
4444
FUNCTION_PASS("dxil-forward-handle-accesses", DXILForwardHandleAccesses())
45+
FUNCTION_PASS("dxil-refine-types", DXILRefineTypesPass())
4546
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
4647
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
4748
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "DXILOpLowering.h"
2323
#include "DXILPostOptimizationValidation.h"
2424
#include "DXILPrettyPrinter.h"
25+
#include "DXILRefineTypes.h"
2526
#include "DXILResourceAccess.h"
2627
#include "DXILResourceImplicitBinding.h"
2728
#include "DXILRootSignature.h"
@@ -37,10 +38,12 @@
3738
#include "llvm/CodeGen/TargetPassConfig.h"
3839
#include "llvm/IR/IRPrintingPasses.h"
3940
#include "llvm/IR/LegacyPassManager.h"
41+
#include "llvm/IR/PassManager.h"
4042
#include "llvm/InitializePasses.h"
4143
#include "llvm/MC/MCSectionDXContainer.h"
4244
#include "llvm/MC/SectionKind.h"
4345
#include "llvm/MC/TargetRegistry.h"
46+
#include "llvm/Passes/OptimizationLevel.h"
4447
#include "llvm/Passes/PassBuilder.h"
4548
#include "llvm/Support/CodeGen.h"
4649
#include "llvm/Support/Compiler.h"
@@ -147,6 +150,11 @@ DirectXTargetMachine::~DirectXTargetMachine() {}
147150
void DirectXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
#define GET_PASS_REGISTRY "DirectXPassRegistry.def"
#include "llvm/Passes/TargetPassRegistry.inc"

  // Add the type refinement pass at the peephole extension point, for every
  // optimization level.
  auto AddRefineTypes = [](FunctionPassManager &PM, OptimizationLevel) {
    PM.addPass(DXILRefineTypesPass());
  };
  PB.registerPeepholeEPCallback(AddRefineTypes);
}
151159

152160
bool DirectXTargetMachine::addPassesToEmitFile(
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
; RUN: split-file %s %t
2+
; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-64.ll | FileCheck %s
3+
; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-64.ll | FileCheck %s
4+
; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-32.ll | FileCheck %s
5+
; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-32.ll | FileCheck %s
6+
7+
; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/folded.ll | FileCheck %s --check-prefix=FOLDED
8+
9+
; Tests that dxil-refine-types will catch the access pattern generated by inst-combine
10+
11+
; CHECK-LABEL: @test(
12+
; CHECK: %[[#FROM:]] = load %struct.PromotedStruct, ptr %get_access
13+
; CHECK: store %struct.PromotedStruct %[[#FROM]], ptr %param
14+
; CHECK: call void @external_barrier(ptr{{.*}}%param)
15+
; CHECK: %[[#TO:]] = load %struct.PromotedStruct, ptr %param
16+
; CHECK: store %struct.PromotedStruct %[[#TO]], ptr %set_access
17+
; CHECK: ret void
18+
19+
;--- explicit-64.ll
20+
21+
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
22+
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
23+
%struct.PromotedStruct = type { i32, float }
24+
25+
@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
26+
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
27+
28+
define void @test(i32 %idx) {
29+
%param = alloca %struct.PromotedStruct, align 1
30+
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
31+
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
32+
%1 = load i64, ptr %get_access, align 1
33+
store i64 %1, ptr %param, align 1
34+
35+
call void @external_barrier(ptr %param)
36+
37+
%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
38+
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
39+
%2 = load i64, ptr %param, align 1
40+
store i64 %2, ptr %set_access, align 1
41+
ret void
42+
}
43+
44+
declare void @external_barrier(ptr)
45+
46+
;--- generated-64.ll
47+
48+
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
49+
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
50+
%struct.PromotedStruct = type { i32, float }
51+
52+
@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
53+
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
54+
55+
define void @test(i32 %idx) {
56+
%param = alloca %struct.PromotedStruct, align 1
57+
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
58+
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
59+
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
60+
61+
call void @external_barrier(ptr %param)
62+
63+
%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
64+
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
65+
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
66+
ret void
67+
}
68+
69+
declare void @external_barrier(ptr)
70+
71+
;--- explicit-32.ll
72+
73+
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
74+
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
75+
%struct.PromotedStruct = type { i16, half }
76+
77+
@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
78+
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
79+
80+
define void @test(i32 %idx) {
81+
%param = alloca %struct.PromotedStruct, align 1
82+
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
83+
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
84+
%1 = load i32, ptr %get_access, align 1
85+
store i32 %1, ptr %param, align 1
86+
87+
call void @external_barrier(ptr %param)
88+
89+
%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
90+
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
91+
%2 = load i32, ptr %param, align 1
92+
store i32 %2, ptr %set_access, align 1
93+
ret void
94+
}
95+
96+
declare void @external_barrier(ptr)
97+
98+
;--- generated-32.ll

; The struct is 4 bytes wide so that instcombine turns the 4-byte memcpy calls
; below into i32 load/store pairs, mirroring explicit-32.ll.

%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i16, half }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

define void @test(i32 %idx) {
  %param = alloca %struct.PromotedStruct, align 1
  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 4, i1 false)

  call void @external_barrier(ptr %param)

  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 4, i1 false)
  ret void
}

declare void @external_barrier(ptr)
122+
123+
;--- folded.ll
124+
125+
; FOLDED-LABEL: @test_folded(
126+
; FOLDED: %[[#TO:]] = load %struct.PromotedStruct, ptr %get_access
127+
; FOLDED: store %struct.PromotedStruct %[[#TO]], ptr %set_access
128+
; FOLDED: ret void
129+
130+
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
131+
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
132+
%struct.PromotedStruct = type { i32, float }
133+
134+
@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
135+
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
136+
137+
define void @test_folded(i32 %idx) {
138+
%param = alloca %struct.PromotedStruct, align 1
139+
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
140+
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
141+
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
142+
143+
%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
144+
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
145+
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
146+
ret void
147+
}

0 commit comments

Comments
 (0)