diff --git a/llvm/lib/Target/DirectX/CMakeLists.txt b/llvm/lib/Target/DirectX/CMakeLists.txt
index 6c079517e22d6..da61406eb88b8 100644
--- a/llvm/lib/Target/DirectX/CMakeLists.txt
+++ b/llvm/lib/Target/DirectX/CMakeLists.txt
@@ -31,6 +31,7 @@ add_llvm_target(DirectXCodeGen
   DXILPostOptimizationValidation.cpp
   DXILPrepare.cpp
   DXILPrettyPrinter.cpp
+  DXILRefineTypes.cpp
   DXILResourceAccess.cpp
   DXILResourceImplicitBinding.cpp
   DXILShaderFlags.cpp
diff --git a/llvm/lib/Target/DirectX/DXILRefineTypes.cpp b/llvm/lib/Target/DirectX/DXILRefineTypes.cpp
new file mode 100644
index 0000000000000..e5065768046d0
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRefineTypes.cpp
@@ -0,0 +1,118 @@
+//===- DXILRefineTypes.cpp ------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "DXILRefineTypes.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/Module.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "dxil-refine-types"
+
+static Type *inferType(Value *Operand) {
+  if (auto *CI = dyn_cast<CallInst>(Operand))
+    if (Function *F = CI->getCalledFunction();
+        F && F->getName().starts_with("llvm.dx.resource.getpointer"))
+      if (auto *ExtType =
+              dyn_cast<TargetExtType>(CI->getOperand(0)->getType()))
+        return ExtType->getTypeParameter(0); // Valid for all dx.Types
+
+  if (auto *AI = dyn_cast<AllocaInst>(Operand))
+    return AI->getAllocatedType();
+
+  // TODO: Extend to other useful/applicable cases
+  return nullptr;
+}
+
+static Type *mergeInferredTypes(Type *A, Type *B) {
+  if (!A)
+    return B;
+
+  if (!B)
+    return A;
+
+  if (A == B)
+    return A;
+
+  // Otherwise, neither was inferred, or inferred differently
+  return nullptr;
+}
+
+bool DXILRefineTypesPass::runImpl(Function &F) {
+  // First detect the pattern: (generated from SimplifyAnyMemTransfer)
+  //   %temp = load Type, ptr %src, ...
+  //   store Type %temp, ptr %dest, ...
+  // where,
+  //   store is the only user of %temp
+  //   Type is either an i32 or i64
+  //
+  // We are currently only concerned with i32 and i64 as these can incidentally
+  // promote 16/32 bit types to 32/64 bit arithmetic.
+  SmallVector<std::pair<LoadInst *, StoreInst *>, 4> ToVisit;
+  for (BasicBlock &BB : F)
+    for (Instruction &I : BB)
+      if (auto *LI = dyn_cast<LoadInst>(&I))
+        if (LI->hasOneUse())
+          if (auto *SI = dyn_cast<StoreInst>(LI->user_back()))
+            if (LI->getAccessType() == SI->getAccessType())
+              if (LI->getAccessType()->isIntegerTy(32) ||
+                  LI->getAccessType()->isIntegerTy(64))
+                ToVisit.push_back({LI, SI});
+
+  bool Modified = false;
+  for (auto [LI, SI] : ToVisit) {
+    Type *LoadTy = inferType(LI->getPointerOperand());
+    Type *StoreTy = inferType(SI->getPointerOperand());
+
+    Type *const InferredTy = mergeInferredTypes(LoadTy, StoreTy);
+    if (!InferredTy || InferredTy == LI->getType())
+      continue; // Nothing to be done. Skip.
+
+    // Replace the type of the load/store
+    IRBuilder<> LoadBuilder(SI);
+    LoadInst *TypedLoad =
+        LoadBuilder.CreateLoad(InferredTy, LI->getPointerOperand());
+
+    TypedLoad->setAlignment(LI->getAlign());
+    TypedLoad->setVolatile(LI->isVolatile());
+    TypedLoad->setOrdering(LI->getOrdering());
+    TypedLoad->setAAMetadata(LI->getAAMetadata());
+    TypedLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access);
+    TypedLoad->copyMetadata(*LI, LLVMContext::MD_access_group);
+
+    IRBuilder<> StoreBuilder(SI);
+    StoreInst *TypedStore = StoreBuilder.CreateStore(
+        TypedLoad, SI->getPointerOperand(), SI->isVolatile());
+
+    TypedStore->setAlignment(SI->getAlign());
+    TypedStore->setVolatile(SI->isVolatile());
+    TypedStore->setOrdering(SI->getOrdering());
+    TypedStore->setAAMetadata(SI->getAAMetadata());
+    TypedStore->copyMetadata(*SI, LLVMContext::MD_mem_parallel_loop_access);
+    TypedStore->copyMetadata(*SI, LLVMContext::MD_access_group);
+    TypedStore->copyMetadata(*SI, LLVMContext::MD_DIAssignID);
+
+    SI->eraseFromParent();
+    LI->eraseFromParent();
+
+    Modified = true;
+  }
+
+  return Modified;
+}
+
+PreservedAnalyses DXILRefineTypesPass::run(Function &F,
+                                           FunctionAnalysisManager &AM) {
+  if (!runImpl(F))
+    return PreservedAnalyses::all();
+
+  // TODO: Can probably preserve some CFG analyses
+  return PreservedAnalyses::none();
+}
diff --git a/llvm/lib/Target/DirectX/DXILRefineTypes.h b/llvm/lib/Target/DirectX/DXILRefineTypes.h
new file mode 100644
index 0000000000000..043e4eb2875f0
--- /dev/null
+++ b/llvm/lib/Target/DirectX/DXILRefineTypes.h
@@ -0,0 +1,40 @@
+//===- DXILRefineTypes.h - Infer additional type information ----*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This pass infers increased type information fidelity of memory operations and
+// replaces their uses with it. Primarily load/store.
+//
+// This is used to prevent the propagation of introduced problematic types. For
+// instance: the InstCombine pass can promote aggregates of 16/32-bit types to
+// be i32/i64 loads and stores.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
+#define LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/PassManager.h"
+
+namespace llvm {
+
+class Function;
+
+class DXILRefineTypesPass : public PassInfoMixin<DXILRefineTypesPass> {
+private:
+  bool runImpl(Function &F);
+
+public:
+  DXILRefineTypesPass() = default;
+
+  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
+};
+
+} // end namespace llvm
+
+#endif // LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
diff --git a/llvm/lib/Target/DirectX/DirectXPassRegistry.def b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
index b4b48a166800e..8f034e4ba10ff 100644
--- a/llvm/lib/Target/DirectX/DirectXPassRegistry.def
+++ b/llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -42,6 +42,7 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
 #define FUNCTION_PASS(NAME, CREATE_PASS)
 #endif
 FUNCTION_PASS("dxil-forward-handle-accesses", DXILForwardHandleAccesses())
+FUNCTION_PASS("dxil-refine-types", DXILRefineTypesPass())
 FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
 FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
 #undef FUNCTION_PASS
diff --git a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
index 84b1a313df2ea..785428962fe04 100644
--- a/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
+++ b/llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -22,6 +22,7 @@
 #include "DXILOpLowering.h"
 #include "DXILPostOptimizationValidation.h"
 #include "DXILPrettyPrinter.h"
+#include "DXILRefineTypes.h"
 #include "DXILResourceAccess.h"
 #include "DXILResourceImplicitBinding.h"
 #include "DXILRootSignature.h"
@@ -37,10 +38,12 @@
 #include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/IR/IRPrintingPasses.h"
 #include "llvm/IR/LegacyPassManager.h"
+#include "llvm/IR/PassManager.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/MC/MCSectionDXContainer.h"
 #include "llvm/MC/SectionKind.h"
 #include "llvm/MC/TargetRegistry.h"
+#include "llvm/Passes/OptimizationLevel.h"
 #include "llvm/Passes/PassBuilder.h"
 #include "llvm/Support/CodeGen.h"
 #include "llvm/Support/Compiler.h"
@@ -147,6 +150,11 @@ DirectXTargetMachine::~DirectXTargetMachine() {}
 void DirectXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
 #define GET_PASS_REGISTRY "DirectXPassRegistry.def"
 #include "llvm/Passes/TargetPassRegistry.inc"
+
+  PB.registerPeepholeEPCallback(
+      [](FunctionPassManager &FPM, OptimizationLevel) {
+        FPM.addPass(DXILRefineTypesPass());
+      });
 }
 
 bool DirectXTargetMachine::addPassesToEmitFile(
diff --git a/llvm/test/CodeGen/DirectX/refine-access-load-store.ll b/llvm/test/CodeGen/DirectX/refine-access-load-store.ll
new file mode 100644
index 0000000000000..b3f9b9c13e338
--- /dev/null
+++ b/llvm/test/CodeGen/DirectX/refine-access-load-store.ll
@@ -0,0 +1,150 @@
+; RUN: split-file %s %t
+; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-64.ll | FileCheck %s
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-64.ll | FileCheck %s
+; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-32.ll | FileCheck %s
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-32.ll | FileCheck %s
+
+; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/folded.ll | FileCheck %s --check-prefix=FOLDED
+
+; Tests that dxil-refine-types will catch the access pattern generated by instcombine
+
+; CHECK-LABEL: @test(
+; CHECK: %[[#FROM:]] = load %struct.PromotedStruct, ptr %get_access
+; CHECK: store %struct.PromotedStruct %[[#FROM]], ptr %param
+; CHECK: call void @external_barrier(ptr{{.*}}%param)
+; CHECK: %[[#TO:]] = load %struct.PromotedStruct, ptr %param
+; CHECK: store %struct.PromotedStruct %[[#TO]], ptr %set_access
+; CHECK: ret void
+
+;--- explicit-64.ll
+
+%"StructuredBuffer" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer"
+@dest = external constant %"RWStructuredBuffer"
+
+define void @test(i32 %idx) {
+  %param = alloca %struct.PromotedStruct, align 1
+  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+  %1 = load i64, ptr %get_access, align 1
+  store i64 %1, ptr %param, align 1
+
+  call void @external_barrier(ptr %param)
+
+  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+  %2 = load i64, ptr %param, align 1
+  store i64 %2, ptr %set_access, align 1
+  ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- generated-64.ll
+
+%"StructuredBuffer" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer"
+@dest = external constant %"RWStructuredBuffer"
+
+define void @test(i32 %idx) {
+  %param = alloca %struct.PromotedStruct, align 1
+  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
+
+  call void @external_barrier(ptr %param)
+
+  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
+  ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- explicit-32.ll
+
+%"StructuredBuffer" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i16, half }
+
+@src = external constant %"StructuredBuffer"
+@dest = external constant %"RWStructuredBuffer"
+
+define void @test(i32 %idx) {
+  %param = alloca %struct.PromotedStruct, align 1
+  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+  %1 = load i32, ptr %get_access, align 1
+  store i32 %1, ptr %param, align 1
+
+  call void @external_barrier(ptr %param)
+
+  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+  %2 = load i32, ptr %param, align 1
+  store i32 %2, ptr %set_access, align 1
+  ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- generated-32.ll
+
+%"StructuredBuffer" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i16, half }
+
+@src = external constant %"StructuredBuffer"
+@dest = external constant %"RWStructuredBuffer"
+
+define void @test(i32 %idx) {
+  %param = alloca %struct.PromotedStruct, align 1
+  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 4, i1 false)
+
+  call void @external_barrier(ptr %param)
+
+  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 4, i1 false)
+  ret void
+}
+
+declare void @external_barrier(ptr)
+
+;--- folded.ll
+
+; This tests that when there is no function call to prevent folding, the type
+; will still successfully be replaced.
+
+; FOLDED-LABEL: @test_folded(
+; FOLDED: %[[#TO:]] = load %struct.PromotedStruct, ptr %get_access
+; FOLDED: store %struct.PromotedStruct %[[#TO]], ptr %set_access
+; FOLDED: ret void
+
+%"StructuredBuffer" = type { %struct.PromotedStruct }
+%"RWStructuredBuffer" = type { %struct.PromotedStruct }
+%struct.PromotedStruct = type { i32, float }
+
+@src = external constant %"StructuredBuffer"
+@dest = external constant %"RWStructuredBuffer"
+
+define void @test_folded(i32 %idx) {
+  %param = alloca %struct.PromotedStruct, align 1
+  %src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
+  %get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)
+
+  %dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
+  %set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
+  call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
+  ret void
+}