Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ add_llvm_target(DirectXCodeGen
DXILPostOptimizationValidation.cpp
DXILPrepare.cpp
DXILPrettyPrinter.cpp
DXILRefineTypes.cpp
DXILResourceAccess.cpp
DXILResourceImplicitBinding.cpp
DXILShaderFlags.cpp
Expand Down
118 changes: 118 additions & 0 deletions llvm/lib/Target/DirectX/DXILRefineTypes.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
//===- DXILRefineTypes.cpp ------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "DXILRefineTypes.h"
#include "llvm/IR/DataLayout.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Instructions.h"
#include "llvm/IR/Intrinsics.h"
#include "llvm/IR/Module.h"

using namespace llvm;

#define DEBUG_TYPE "dxil-refine-types"

/// Try to recover the element type that is accessed through \p Operand.
///
/// Two kinds of pointer producers are understood:
///  - a direct call whose callee name starts with
///    "llvm.dx.resource.getpointer": the first type parameter of the target
///    extension type of the call's first operand is returned;
///  - an alloca: its allocated type is returned.
///
/// \returns the inferred type, or nullptr if nothing could be inferred.
static Type *inferType(Value *Operand) {
  if (auto *CI = dyn_cast<CallInst>(Operand)) {
    // getCalledFunction() returns null for indirect calls, so it has to be
    // checked before the callee name is inspected.
    const Function *Callee = CI->getCalledFunction();
    if (Callee && Callee->getName().starts_with("llvm.dx.resource.getpointer"))
      if (auto *ExtType = dyn_cast<TargetExtType>(CI->getOperand(0)->getType()))
        // Do not index into an extension type without type parameters.
        if (ExtType->getNumTypeParameters() > 0)
          return ExtType->getTypeParameter(0);

    // A call can never be an alloca, nothing else to try.
    return nullptr;
  }

  if (auto *AI = dyn_cast<AllocaInst>(Operand))
    return AI->getAllocatedType();

  // TODO: Extend to other useful/applicable cases
  return nullptr;
}

/// Combine the types inferred for the two ends of a load/store pair.
///
/// When only one side produced a type, that type is used. When both sides
/// produced a type they have to be identical. The result is nullptr when
/// neither side produced a type or when the two types disagree.
static Type *mergeInferredTypes(Type *A, Type *B) {
  // At most one side carries information: hand back whichever one it is
  // (which is nullptr if both are missing).
  if (!A || !B)
    return A ? A : B;

  // Both sides were inferred; the result is only usable if they agree.
  return A == B ? A : nullptr;
}

/// Replace integer-typed load/store copies with copies that use the type
/// inferred for the accessed memory.
///
/// The pattern that is matched (generated from SimplifyAnyMemTransfer) is:
///   %temp = load Type, ptr %src, ...
///   store Type %temp, ptr %dest, ...
/// where
///   - the store is the only user of %temp,
///   - Type is either i32 or i64,
///   - neither access is atomic.
///
/// Only i32 and i64 are of concern, as these can incidentally promote
/// 16/32-bit types to 32/64-bit arithmetic.
///
/// \returns true if at least one load/store pair in \p F was replaced.
bool DXILRefineTypesPass::runImpl(Function &F) {
  // Collect the candidates first so that the instruction lists are not
  // modified while they are being iterated.
  SmallVector<std::pair<LoadInst *, StoreInst *>, 4> ToVisit;
  for (BasicBlock &BB : F) {
    for (Instruction &I : BB) {
      auto *LI = dyn_cast<LoadInst>(&I);
      if (!LI || !LI->hasOneUse())
        continue;

      auto *SI = dyn_cast<StoreInst>(LI->user_back());
      if (!SI || LI->getAccessType() != SI->getAccessType())
        continue;

      // The refined type is usually an aggregate, which is not a valid type
      // for an atomic load or store. Leave atomic accesses untouched.
      if (LI->isAtomic() || SI->isAtomic())
        continue;

      Type *AccessTy = LI->getAccessType();
      if (AccessTy->isIntegerTy(32) || AccessTy->isIntegerTy(64))
        ToVisit.push_back({LI, SI});
    }
  }

  const DataLayout &DL = F.getParent()->getDataLayout();

  bool Modified = false;
  for (auto [LI, SI] : ToVisit) {
    Type *LoadTy = inferType(LI->getPointerOperand());
    Type *StoreTy = inferType(SI->getPointerOperand());

    Type *const InferredTy = mergeInferredTypes(LoadTy, StoreTy);
    if (!InferredTy || InferredTy == LI->getType())
      continue; // Nothing to be done. Skip.

    // The typed copy must move exactly the bytes of the integer copy. If the
    // inferred type has a different size (for instance the integer access
    // only covers a prefix of a larger alloca), retyping the access would
    // change how much memory is read and written.
    if (!InferredTy->isSized() ||
        DL.getTypeStoreSize(InferredTy) != DL.getTypeStoreSize(LI->getType()))
      continue;

    // Create the typed load at the position of the original load. Creating
    // it at the store would sink the load past every instruction between the
    // two, including ones that may write to the loaded memory.
    IRBuilder<> LoadBuilder(LI);
    LoadInst *TypedLoad =
        LoadBuilder.CreateLoad(InferredTy, LI->getPointerOperand());

    TypedLoad->setAlignment(LI->getAlign());
    TypedLoad->setVolatile(LI->isVolatile());
    TypedLoad->setAAMetadata(LI->getAAMetadata());
    TypedLoad->copyMetadata(*LI, LLVMContext::MD_mem_parallel_loop_access);
    TypedLoad->copyMetadata(*LI, LLVMContext::MD_access_group);

    // The typed store takes the place of the original store.
    IRBuilder<> StoreBuilder(SI);
    StoreInst *TypedStore = StoreBuilder.CreateStore(
        TypedLoad, SI->getPointerOperand(), SI->isVolatile());

    TypedStore->setAlignment(SI->getAlign());
    TypedStore->setAAMetadata(SI->getAAMetadata());
    TypedStore->copyMetadata(*SI, LLVMContext::MD_mem_parallel_loop_access);
    TypedStore->copyMetadata(*SI, LLVMContext::MD_access_group);
    TypedStore->copyMetadata(*SI, LLVMContext::MD_DIAssignID);

    // Erase the store first: it is the only user of the original load.
    SI->eraseFromParent();
    LI->eraseFromParent();

    Modified = true;
  }

  return Modified;
}

/// New pass manager entry point: runs the refinement on \p F and reports
/// which analyses are still valid afterwards.
PreservedAnalyses DXILRefineTypesPass::run(Function &F,
                                           FunctionAnalysisManager &AM) {
  if (!runImpl(F))
    return PreservedAnalyses::all();

  // Only loads and stores are replaced in place; no basic block or terminator
  // is added, removed or changed, so analyses that depend solely on the CFG
  // remain valid.
  PreservedAnalyses PA;
  PA.preserveSet<CFGAnalyses>();
  return PA;
}
40 changes: 40 additions & 0 deletions llvm/lib/Target/DirectX/DXILRefineTypes.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
//===- DXILRefineTypes.h - Infer additional type information ----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass infers increased type information fidelity of memory operations and
// replaces their uses with it. Primarily load/store.
//
// This is used to prevent the propagation of problematic types introduced by
// earlier passes. For instance, the InstCombine pass can turn copies of
// aggregates of 16/32-bit types into i32/i64 loads and stores.
//
//===----------------------------------------------------------------------===//

// The guard follows the location of this header (llvm/lib/Target/DirectX).
#ifndef LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
#define LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H

#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/PassManager.h"

namespace llvm {

class Function;

/// Function pass for the new pass manager that refines the types used by
/// load/store pairs.
class DXILRefineTypesPass : public PassInfoMixin<DXILRefineTypesPass> {
private:
  /// Performs the transformation on \p F. Returns true if \p F was modified.
  bool runImpl(Function &F);

public:
  DXILRefineTypesPass() = default;

  /// Runs the pass on \p F and returns the set of preserved analyses.
  PreservedAnalyses run(Function &F, FunctionAnalysisManager &AM);
};

} // end namespace llvm

#endif // LLVM_LIB_TARGET_DIRECTX_DXILREFINETYPES_H
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-forward-handle-accesses", DXILForwardHandleAccesses())
FUNCTION_PASS("dxil-refine-types", DXILRefineTypesPass())
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
#undef FUNCTION_PASS
8 changes: 8 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include "DXILOpLowering.h"
#include "DXILPostOptimizationValidation.h"
#include "DXILPrettyPrinter.h"
#include "DXILRefineTypes.h"
#include "DXILResourceAccess.h"
#include "DXILResourceImplicitBinding.h"
#include "DXILRootSignature.h"
Expand All @@ -37,10 +38,12 @@
#include "llvm/CodeGen/TargetPassConfig.h"
#include "llvm/IR/IRPrintingPasses.h"
#include "llvm/IR/LegacyPassManager.h"
#include "llvm/IR/PassManager.h"
#include "llvm/InitializePasses.h"
#include "llvm/MC/MCSectionDXContainer.h"
#include "llvm/MC/SectionKind.h"
#include "llvm/MC/TargetRegistry.h"
#include "llvm/Passes/OptimizationLevel.h"
#include "llvm/Passes/PassBuilder.h"
#include "llvm/Support/CodeGen.h"
#include "llvm/Support/Compiler.h"
Expand Down Expand Up @@ -147,6 +150,11 @@ DirectXTargetMachine::~DirectXTargetMachine() {}
void DirectXTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
#define GET_PASS_REGISTRY "DirectXPassRegistry.def"
#include "llvm/Passes/TargetPassRegistry.inc"

  // Schedule the type refinement through the pass builder's peephole
  // extension point. The optimization level is not taken into account.
  auto AddRefineTypes = [](FunctionPassManager &FPM,
                           OptimizationLevel /*Level*/) {
    FPM.addPass(DXILRefineTypesPass());
  };
  PB.registerPeepholeEPCallback(AddRefineTypes);
}

bool DirectXTargetMachine::addPassesToEmitFile(
Expand Down
150 changes: 150 additions & 0 deletions llvm/test/CodeGen/DirectX/refine-access-load-store.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
; RUN: split-file %s %t
; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-64.ll | FileCheck %s
; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-64.ll | FileCheck %s
; RUN: opt -passes=dxil-refine-types -S -mtriple=dxil-pc-shadermodel6.0-library < %t/explicit-32.ll | FileCheck %s
; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/generated-32.ll | FileCheck %s

; RUN: opt -passes='instcombine,dxil-refine-types' -S -mtriple=dxil-pc-shadermodel6.0-library < %t/folded.ll | FileCheck %s --check-prefix=FOLDED

; Tests that dxil-refine-types will catch the access pattern generated by inst-combine

; CHECK-LABEL: @test(
; CHECK: %[[#FROM:]] = load %struct.PromotedStruct, ptr %get_access
; CHECK: store %struct.PromotedStruct %[[#FROM]], ptr %param
; CHECK: call void @external_barrier(ptr{{.*}}%param)
; CHECK: %[[#TO:]] = load %struct.PromotedStruct, ptr %param
; CHECK: store %struct.PromotedStruct %[[#TO]], ptr %set_access
; CHECK: ret void

;--- explicit-64.ll

; Buffer wrapper types around the element struct used by this input.
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i32, float }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

; Hand-written form of the pattern: the struct is copied with explicit i64
; load/store pairs, so this input is run through dxil-refine-types alone.
define void @test(i32 %idx) {
%param = alloca %struct.PromotedStruct, align 1
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
; First copy: resource element -> local alloca, as an i64.
%1 = load i64, ptr %get_access, align 1
store i64 %1, ptr %param, align 1

; The call takes %param and sits between the two copies.
call void @external_barrier(ptr %param)

%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
; Second copy: local alloca -> resource element, as an i64.
%2 = load i64, ptr %param, align 1
store i64 %2, ptr %set_access, align 1
ret void
}

declare void @external_barrier(ptr)

;--- generated-64.ll

; Buffer wrapper types around the element struct used by this input.
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i32, float }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

; Same copies as the explicit 64-bit input, but written as 8-byte memcpy
; calls. This input is run with instcombine ahead of dxil-refine-types, so the
; integer load/store pairs are expected to come from instcombine.
define void @test(i32 %idx) {
%param = alloca %struct.PromotedStruct, align 1
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
; First copy: resource element -> local alloca.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)

; The call takes %param and sits between the two copies.
call void @external_barrier(ptr %param)

%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
; Second copy: local alloca -> resource element.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
ret void
}

declare void @external_barrier(ptr)

;--- explicit-32.ll

; Buffer wrapper types around the element struct used by this input. The
; element is made of two 16-bit members, so it is copied as an i32.
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i16, half }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

; Hand-written form of the pattern with explicit i32 load/store pairs. The
; getpointer calls only claim 4 dereferenceable bytes, matching the size of
; one { i16, half } element.
define void @test(i32 %idx) {
%param = alloca %struct.PromotedStruct, align 1
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
%get_access = call noundef nonnull align 1 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
; First copy: resource element -> local alloca, as an i32.
%1 = load i32, ptr %get_access, align 1
store i32 %1, ptr %param, align 1

; The call takes %param and sits between the two copies.
call void @external_barrier(ptr %param)

%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
%set_access = call noundef nonnull align 1 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
; Second copy: local alloca -> resource element, as an i32.
%2 = load i32, ptr %param, align 1
store i32 %2, ptr %set_access, align 1
ret void
}

declare void @external_barrier(ptr)

;--- generated-32.ll

; Buffer wrapper types around the element struct used by this input. The
; element is the 4-byte { i16, half } struct so that the memcpy calls below
; are turned into i32 (not i64) load/store pairs.
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i16, half }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

; Same copies as the explicit 32-bit input, but written as 4-byte memcpy
; calls. This input is run with instcombine ahead of dxil-refine-types, so the
; i32 load/store pairs are expected to come from instcombine.
define void @test(i32 %idx) {
%param = alloca %struct.PromotedStruct, align 1
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
%get_access = call noundef nonnull align 1 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
; First copy: resource element -> local alloca.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 4, i1 false)

; The call takes %param and sits between the two copies.
call void @external_barrier(ptr %param)

%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
%set_access = call noundef nonnull align 1 dereferenceable(4) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
; Second copy: local alloca -> resource element.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 4, i1 false)
ret void
}

declare void @external_barrier(ptr)

;--- folded.ll

; This tests that when there is no function call to prevent folding, the type
; will still successfully be replaced.

; FOLDED-LABEL: @test_folded(
; FOLDED: %[[#TO:]] = load %struct.PromotedStruct, ptr %get_access
; FOLDED: store %struct.PromotedStruct %[[#TO]], ptr %set_access
; FOLDED: ret void

; Buffer wrapper types around the element struct used by this input.
%"StructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%"RWStructuredBuffer<struct.PromotedStruct>" = type { %struct.PromotedStruct }
%struct.PromotedStruct = type { i32, float }

@src = external constant %"StructuredBuffer<struct.PromotedStruct>"
@dest = external constant %"RWStructuredBuffer<struct.PromotedStruct>"

; Two back-to-back 8-byte memcpy calls through %param with no call in
; between. The expected output copies directly from %get_access to
; %set_access using the struct type.
define void @test_folded(i32 %idx) {
%param = alloca %struct.PromotedStruct, align 1
%src = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @src, align 4
%get_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %src, i32 %idx)
; First copy: resource element -> local alloca.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %param, ptr align 1 %get_access, i32 8, i1 false)

%dest = load target("dx.RawBuffer", %struct.PromotedStruct, 1, 0), ptr @dest, align 4
%set_access = call noundef nonnull align 1 dereferenceable(8) ptr @llvm.dx.resource.getpointer.p0.tdx.RawBuffer_s_struct.PromotedStructs_1_0t(target("dx.RawBuffer", %struct.PromotedStruct, 1, 0) %dest, i32 %idx)
; Second copy: local alloca -> resource element.
call void @llvm.memcpy.p0.p0.i32(ptr align 1 %set_access, ptr align 1 %param, i32 8, i1 false)
ret void
}