[DirectX] Add DXILRefineTypesPass to prevent type promotion
#169384
Conversation
`%struct.PromotedStruct = type { i16, half }`
This pass could also be made more general and added to the opt pipeline. However, registering a callback to infer types for backend specific instructions is a little more complex, so I have left it just for DirectX atm.
@llvm/pr-subscribers-backend-directx

Author: Finn Plummer (inbelic)
Full diff: https://github.com/llvm/llvm-project/pull/169384.diff 6 Files Affected:
diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 6c079517e22d6..da61406eb88b8 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -31,6 +31,7 @@ add_llvm_target(DirectXCodeGen
DXILPostOptimizationValidation.cpp
DXILPrepare.cpp
DXILPrettyPrinter.cpp
+ DXILRefineTypes.cpp
DXILResourceAccess.cpp
DXILResourceImplicitBinding.cpp
DXILShaderFlags.cpp
diff --git a/llvm/lib/Target/DirectX/DXILRefineTypes.cpp b/llvm/lib/Target/DirectX/DXILRefineTypes.cpp
new file mode 100644
index 0000000000000..e5065768046d0
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRefineTypes.cpp
@@ -0,0 +1,118 @@
+//===- DXILRefineTypes.cpp ----------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILRefineTypes.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "dxil-refine-types"
+
+static Type *inferType(Value *Operand) {
+ if (auto *CI = dyn_cast<CallInst>(Operand))
+ if (CI->getCalledFunction()->getName().starts_with(
+ "llvm.dx.resource.getpointer"))
+ if (auto *ExtType = dyn_cast<TargetExtType>(CI->getOperand(0)->getType()))
+ return ExtType->getTypeParameter(0); // Valid for all dx.Types
+
+ if (auto *AI = dyn_cast<AllocaInst>(Operand))
+ return AI->getAllocatedType();
+
+ // TODO: Extend to other useful/applicable cases
+ return nullptr;
+}
+
+static Type *mergeInferredTypes(Type *A, Type *B) {
+ if (!A)
+ return B;
+
+ if (!B)
+ return A;
+
+ if (A == B)
+ return A;
+
+ // Otherwise, neither was inferred, or inferred differently
+ return nullptr;
+}
+
+bool DXILRefineTypesPass::runImpl(Function &F) {
+ // First detect the pattern: (generated from SimplifyAnyMemTransfer)
+ // %temp = load Type, ptr %src, ...
+ // store Type %temp, ptr %dest, ...
+ // where,
+ // store is the only user of %temp
+ // Type is either an i32 or i64
+ //
+ // We are currently only concerned with i32 and i64 as these can incidentally
+ // promote 16/32 bit types to 32/64 bit arithmetic.
+ SmallVector<std::pair<LoadInst *, StoreInst *>, 4> ToVisit;
+ for (BasicBlock &BB : F)
+ for (Instruction &I : BB)
+ if (auto *LI = dyn_cast<LoadInst>(&I))
+ if (LI->hasOneUse())
+ if (auto *SI = dyn_cast<StoreInst>(LI->user_back()))
+ if (LI->getAccessType() == SI->getAccessType())
+ if (LI->getAccessType()->isIntegerTy(32) ||
+ LI->getAccessType()->isIntegerTy(64))
+ ToVisit.push_back({LI, SI});
+
+ bool Modified = false;
+ for (auto [LI, SI] : ToVisit) {
+ Type *LoadTy = inferType(LI->getPointerOperand());
+ Type *StoreTy = inferType(SI->getPointerOperand());
+
+ Type *const InferredTy = mergeInferredTypes(LoadTy, StoreTy);
+ if (!InferredTy || InferredTy == LI->getType())
+ continue; // Nothing to be done. Skip.
+
+ // Replace the type of the load/store
+ IRBuilder<> LoadBuilder(SI);
+ LoadInst *TypedLoad =
+ LoadBuilder.CreateLoad(InferredTy, LI->getPointerOperand());
+
+ TypedLoad->setAlignment(LI->getAlign());
+ TypedLoad->setVolatile(LI->isVolatile());
+ TypedLoad->setOrdering(LI->getOrdering());
+ TypedLoad->setAAMetadata(LI->getAAMetadata());
+ TypedLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access);
+ TypedLoad->copyMetadata(*LI, LLVMContext::MD_access_group);
+
+ IRBuilder<> StoreBuilder(SI);
+ StoreInst *TypedStore = StoreBuilder.CreateStore(
+ TypedLoad, SI->getPointerOperand(), SI->isVolatile());
+
+ TypedStore->setAlignment(SI->getAlign());
+ TypedStore->setVolatile(SI->isVolatile());
+ TypedStore->setOrdering(SI->getOrdering());
+ TypedStore->setAAMetadata(SI->getAAMetadata());
+ TypedStore->copyMetadata(*SI, LLVMContext::MD_mem_parallel_loop_access);
+ TypedStore->copyMetadata(*SI, LLVMContext::MD_access_group);
+ TypedStore->copyMetadata(*SI, LLVMContext::MD_DIAssignID);
+
+ SI->eraseFromParent();
+ LI->eraseFromParent();
+
+ Modified = true;
+ }
+
+ return Modified;
+}
+
+PreservedAnalyses DXILRefineTypesPass::run(Function &F,
+ FunctionAnalysisManager &AM) {
+ if (!runImpl(F))
+ return PreservedAnalyses::all();
+
+ // TODO: Can probably preserve some CFG analyses
+ return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Target/DirectX/DXILRefineTypes.h b/llvm/lib/Target/DirectX/DXILRefineTypes.h
new file mode 100644
index 0000000000000..043e4eb2875f0
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRefineTypes.h
@@ -0,0 +1,40 @@
+//===- DXILRefineTypes.h - Infer additional type information ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass infers higher-fidelity types for memory operations, primarily
+// load/store, and rewrites them to use the inferred type.
+//
+// This is used to prevent the propagation of introduced problematic types. For
+// instance: the InstCombine pass can promote aggregates of 16/32-bit types to
+// be i32/64 loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_TRANSFORMS_SCALAR_DXILREFINETYPES_H
+#define LLVM_TRANSFORMS_SCALAR_DXILREFINETYPES_H
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class Function;
+
+class DXILRefineTypesPass : public PassInfoMixin<DXILRefineTypesPass> {
+private:
+ bool runImpl(Function &F);
+
+public:
+ DXILRefineTypesPass() = default;
+
+ PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_TRANSFORMS_SCALAR_DXILREFINETYPES_H
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index b4b48a166800e..8f034e4ba10ff 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -42,6 +42,7 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-forward-handle-accesses", DXILForwardHandleAccesses())
+FUNCTION_PASS("dxil-refine-types", DXILRefineTypesPass())
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
#undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 84b1a313df2ea..785428962fe04 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -22,6 +22,7 @@
#include "DXILOpLowering.h"
#include "DXILPostOptimizationValidation.h"
#include "DXILPrettyPrinter.h"
+#include "DXILRefineTypes.h"
#include "DXILResourceAccess.h"
#include "DXILResourceImplicitBinding.h"
#include "DXILRootSignature.h"
@@ -37,10 +38,12 @@
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/MCSectionDXContainer.h"
#include "llvm/MC/SectionKind.h"
#include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Compiler.h"
@@ -147,6 +150,11 @@ DirectXTargetMachine::~DirectXTargetMachine() {}
void DirectXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
#define GET_PASS_REGISTRY "DirectXPassRegistry.def"
#include "llvm/Passes/TargetPassRegistry.inc"
+
+ PB.registerPeepholeEPCallback(
+ [](FunctionPassManager &FPM, OptimizationLevel) {
+ FPM.addPass(DXILRefineTypesPass());
+ });
}
bool DirectXTargetMachine::addPassesToEmitFile(
diff --git a/llvm/test/CodeGen/DirectX/refine-access-load-store.ll b/llvm/test/CodeGen/DirectX/refine-access-load-store.ll
new file mode 100644
index 0000000000000..b3f9b9c13e338
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/refine-access-load-store.ll
@@ -0,0 +1,150 @@
+; RUN: split-file %s %t
+; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-64.ll | FileCheck %s
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-64.ll | FileCheck %s
+; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-32.ll | FileCheck %s
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-32.ll | FileCheck %s
+
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/folded.ll | FileCheck %s --check-prefix=FOLDED
+
+; Tests that dxil-refine-types will catch the access pattern generated by instcombine
+
+; CHECK-LABEL: @test(
+; CHECK: %[[#FROM:]] = load %struct.PromotedStruct, ptr %get_access
+; CHECK: store %struct.PromotedStruct %[[#FROM]], ptr %param
+; CHECK: call void @external_barrier(ptr{{.*}}%param)
+; CHECK: %[[#TO:]] = load %struct.PromotedStruct, ptr %param
+; CHECK: store %struct.PromotedStruct %[[#TO]], ptr %set_access
+; CHECK: ret void
+
+;--- explicit-64.ll
+
+%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
+@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
+
+define void @test(i32 %idx) {
+ %param = alloca %struct.PromotedStruct, align 1
+ %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+ %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+ %1 = load i64, ptr %get_access, align 1
+ store i64 %1, ptr %param, align 1
+
+ call void @external_barrier(ptr %param)
+
+ %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+ %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+ %2 = load i64, ptr %param, align 1
+ store i64 %2, ptr %set_access, align 1
+ ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- generated-64.ll
+
+%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
+@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
+
+define void @test(i32 %idx) {
+ %param = alloca %struct.PromotedStruct, align 1
+ %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+ %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
+
+ call void @external_barrier(ptr %param)
+
+ %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+ %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
+ ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- explicit-32.ll
+
+%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i16, half }
+
+@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
+@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
+
+define void @test(i32 %idx) {
+ %param = alloca %struct.PromotedStruct, align 1
+ %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+ %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+ %1 = load i32, ptr %get_access, align 1
+ store i32 %1, ptr %param, align 1
+
+ call void @external_barrier(ptr %param)
+
+ %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+ %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+ %2 = load i32, ptr %param, align 1
+ store i32 %2, ptr %set_access, align 1
+ ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- generated-32.ll
+
+%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, i32 }
+
+@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
+@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
+
+define void @test(i32 %idx) {
+ %param = alloca %struct.PromotedStruct, align 1
+ %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+ %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
+
+ call void @external_barrier(ptr %param)
+
+ %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+ %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
+ ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- folded.ll
+
+; This tests that when there is no function call to prevent folding, the type
+; will still successfully be replaced.
+
+; FOLDED-LABEL: @test_folded(
+; FOLDED: %[[#TO:]] = load %struct.PromotedStruct, ptr %get_access
+; FOLDED: store %struct.PromotedStruct %[[#TO]], ptr %set_access
+; FOLDED: ret void
+
+%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
+@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"
+
+define void @test_folded(i32 %idx) {
+ %param = alloca %struct.PromotedStruct, align 1
+ %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+ %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
+
+ %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+ %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+ call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
+ ret void
+}
Icohedron
left a comment
Perhaps the LLVM PatternMatch API could be used to simplify the pass a bit.
joaosaffran
left a comment
The PR description is really long, and most of the discussion there seems more relevant in the issue than in the patch itself.
Some minor nits regarding code readability and handling instructions that we missed.
Other than that, LGTM.
nit: I find this suggestion easier to read:

```cpp
if (auto *CI = dyn_cast<CallInst>(Operand)) {
  bool IsGetPointer = CI->getCalledFunction()->getName().starts_with(
      "llvm.dx.resource.getpointer");
  auto *ExtType = dyn_cast<TargetExtType>(CI->getOperand(0)->getType());
  if (IsGetPointer && ExtType)
    return ExtType->getTypeParameter(0);
}
```
`else if (auto *AI = dyn_cast<AllocaInst>(Operand))`
This allows us to know if we are generating code that needs to be handled:

`llvm_unreachable("Instruction not handled in DXILRefineTypes");`
From offline discussion with @bogner it looks like we can try to infer the type directly in the lowering instead. It appears that just inferring from the `llvm.dx.resource.getpointer` call is sufficient. I will open up a PR with reference to this one for context.
Problem Statement
Consider the following pattern:
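Per the comment in the DXILRefineTypes.cpp implementation below, the pattern has this shape (the `%temp`/`%src`/`%dest` names follow the pass's own documentation):

```llvm
%temp = load i64, ptr %src, align 1    ; Type is an i32 or i64
store i64 %temp, ptr %dest, align 1    ; the store is the only user of %temp
```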
where `Type` is an `i32`/`i64` and the `store` is the only user of `%temp`.

This pattern is introduced by the `SimplifyAnyMemTransfer` function during `InstCombine`, which converts `llvm.memcpy`s of size 1/2/4/8 bytes to their load/store equivalent with the corresponding integer type. The `llvm.memcpy`s are frequently introduced when handling resource buffer accesses in HLSL.

The introduction of a load/store with an i32/i64 is problematic when lowering to DXIL because it can incidentally promote operations intended for a half/float or i16/i32. See the linked issue for an example of i32 operations being promoted to use an i64.
While the promoted operations are functionally correct, they can introduce additional driver dependencies (e.g., support for Int64), such that a shader may be valid for a driver when compiled with DXC but not with Clang, which is not acceptable.
Proposed Solution
This commit introduces the `DXILRefineTypesPass`, which matches on the pattern above, infers the type based on one of the `src`/`dest` pointer arguments, and then upgrades the load/store to the higher-fidelity type. By doing so right after the pattern is introduced, it prevents any incidental usage of the i32/i64.

Currently, the pass only infers the type from a parent `alloca` or `dx.resource.getpointer` instruction, as the minimal change to resolve the linked issue. However, it could easily be extended to infer types from more parent operations if an opportunity arises.

This pass is added to `registerPeepholeEPCallback`, which will appropriately run it directly after the `InstCombine` run that introduces the pattern.

In addition to the alternatives listed below, this approach has the additional benefit of keeping the simplified form of the `llvm.memcpy` as a load/store.

Resolves: #165753.
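As a concrete sketch, adapted from the explicit-64 test case below, the rewrite turns the type-punned copy into a struct-typed one:

```llvm
; Before dxil-refine-types: InstCombine's SimplifyAnyMemTransfer reduced the
; llvm.memcpy to an i64 load/store pair
%1 = load i64, ptr %get_access, align 1
store i64 %1, ptr %param, align 1

; After dxil-refine-types: the type is inferred from the
; llvm.dx.resource.getpointer source and the alloca destination
%2 = load %struct.PromotedStruct, ptr %get_access, align 1
store %struct.PromotedStruct %2, ptr %param, align 1
```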
Alternatives Considered
- Inline such that the `llvm.memcpy`s are removed by component accesses prior to invoking the `InstCombine` pass. This would functionally resolve the linked issue; however, inlining much earlier would (most likely) result in a noticeable increase in compile time.
- Prevent the promotion in `InstCombine`. We could alternatively provide a new parameter (or otherwise) to the `InstCombine` pass that would prevent the promotion of certain types. This would be a general option that other languages can enable when applicable. This would also resolve the issue by ensuring the `llvm.memcpy` is not modified during the pass, and allow for greater certainty that this would not slip by in the future (though from my perspective that seems unlikely). However, this is a much more obtrusive change than the proposed solution.