[DirectX] Start the creation of a DXIL Instruction legalizer #131221

Merged
merged 7 commits into from
Mar 18, 2025
Changes from 6 commits
1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/CMakeLists.txt
@@ -32,6 +32,7 @@ add_llvm_target(DirectXCodeGen
DXILShaderFlags.cpp
DXILTranslateMetadata.cpp
DXILRootSignature.cpp
DXILLegalizePass.cpp

LINK_COMPONENTS
Analysis
212 changes: 212 additions & 0 deletions llvm/lib/Target/DirectX/DXILLegalizePass.cpp
@@ -0,0 +1,212 @@
//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//
//===---------------------------------------------------------------------===//
Collaborator

nit: Doubled heading line is unnecessary.

///
/// \file This file contains a pass to remove i8 truncations and i64 extract
/// and insert elements.
Collaborator

Is this the only set of legalizations you expect this pass to handle? If not, this will probably get out of date quickly.

///
//===----------------------------------------------------------------------===//
#include "DXILLegalizePass.h"
#include "DirectX.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstIterator.h"
#include "llvm/IR/Instruction.h"
#include "llvm/Pass.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
#include <functional>
#include <map>
#include <stack>
#include <vector>

#define DEBUG_TYPE "dxil-legalize"

using namespace llvm;
namespace {
Collaborator

nit: don't include static functions in the anonymous namespace.

Because of this, we have a simple guideline: make anonymous namespaces as small as possible, and only use them for class declarations.

see: https://llvm.org/docs/CodingStandards.html#restrict-visibility
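
A minimal sketch of the layout this guideline asks for, written as standalone C++ rather than against the LLVM headers. `LegalizerSketch` and `isI8Name` are hypothetical names used only for illustration; they do not appear in the PR.

```cpp
#include <string>

namespace {
// Only the class declaration lives in the anonymous namespace.
class LegalizerSketch {
public:
  bool run(const std::string &TypeName);
};
} // namespace

// File-local helper marked `static`: it still has internal linkage, and that
// fact is visible at the definition without scanning for an enclosing
// namespace.
static bool isI8Name(const std::string &TypeName) { return TypeName == "i8"; }

// Member functions can be defined outside the anonymous namespace as well.
bool LegalizerSketch::run(const std::string &TypeName) {
  return isI8Name(TypeName);
}
```

Applied to this file, that would likely mean closing the anonymous namespace around the class declarations only and leaving the two `static` legalization functions at file scope.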


static bool fixI8TruncUseChain(Instruction &I,
std::stack<Instruction *> &ToRemove,
Collaborator

llvm::SmallVector and llvm::SmallVectorImpl should be preferred over std::stack.

std::map<Value *, Value *> &ReplacedValues) {
Collaborator

llvm::DenseMap should almost always be used instead of std::map.
see: https://llvm.org/docs/CodingStandards.html#c-standard-library


auto *Cmp = dyn_cast<CmpInst>(&I);

if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
if (Trunc->getDestTy()->isIntegerTy(8)) {
ReplacedValues[Trunc] = Trunc->getOperand(0);
ToRemove.push(Trunc);
}
} else if (I.getType()->isIntegerTy(8) ||
(Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) {
IRBuilder<> Builder(&I);

std::vector<Value *> NewOperands;
Collaborator

Please use llvm::SmallVector to avoid dynamic allocations.

Type *InstrType = IntegerType::get(I.getContext(), 32);
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
if (ReplacedValues.count(Op))
InstrType = ReplacedValues[Op]->getType();
Comment on lines +51 to +52
Contributor

Is it possible to get conflicting InstrTypes here? Does this deserve an assert?

Member Author

I want to say no, but not sure what you mean. I'll just explain what I was thinking here and maybe that will answer your question.

So I'm setting the default instruction type to i32; that happens on line 48. That means if we ever have an instruction whose operands are all 8-bit immediates, we sext them all to 32 bits. The loop on line 49 looks at the operands to see whether we have created any replacement values for them. These typically end up being the values before a trunc operation. If one exists, we can set the instruction type from it. This is really only important for BinaryOperator instructions.

You can see the above explanations exercised in the i16_test and the all_imm of llvm/test/CodeGen/DirectX/legalize-i8.ll.
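
The type selection described above can be modeled in a few lines of standalone C++. This is only a sketch under stated assumptions: `ReplacedWidths` and `pickInstrWidth` are hypothetical stand-ins (operand names mapped to bit widths) for the `ReplacedValues` map and the loop in the pass, and the assert shows one possible way to enforce that replaced operands agree, as the question above suggests.

```cpp
#include <cassert>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for ReplacedValues: operand name -> bit width of the
// value that replaces it (typically the value that fed the trunc).
using ReplacedWidths = std::map<std::string, unsigned>;

// Choose the width of the rewritten instruction: default to 32 bits, but
// adopt the width of any operand that already has a replacement.
unsigned pickInstrWidth(const std::vector<std::string> &Operands,
                        const ReplacedWidths &Replaced) {
  unsigned Width = 32;
  bool SawReplacement = false;
  for (const std::string &Op : Operands) {
    auto It = Replaced.find(Op);
    if (It == Replaced.end())
      continue; // immediates and untouched operands keep the current choice
    // Assumption made explicit: all replaced operands share one width.
    assert((!SawReplacement || Width == It->second) &&
           "conflicting replacement widths");
    Width = It->second;
    SawReplacement = true;
  }
  return Width;
}
```

With no replaced operands the function returns 32, which matches the all-immediate case; with an operand whose replacement is 16 bits wide it returns 16, which matches the i16 case.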

}
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
Value *Op = I.getOperand(OpIdx);
if (ReplacedValues.count(Op))
NewOperands.push_back(ReplacedValues[Op]);
else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
APInt Value = Imm->getValue();
unsigned NewBitWidth = InstrType->getIntegerBitWidth();
// Note: options here are sext or sextOrTrunc.
// Since i8 isn't supported, we assume new values
// will always have a higher bitness.
Comment on lines +62 to +63
Contributor

An assert to document/enforce the assumption wouldn't hurt
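
For illustration, a standalone sketch of what such an assert could look like. It models the immediate as a plain `int8_t` instead of an `APInt`; `widenI8Immediate` is a hypothetical helper, not code from the PR.

```cpp
#include <cassert>
#include <cstdint>

// Sign-extend an i8 immediate to NewBitWidth bits and return the resulting
// bit pattern. The assert documents the assumption from the code comment:
// the replacement type is always wider than i8.
uint64_t widenI8Immediate(int8_t Imm, unsigned NewBitWidth) {
  assert(NewBitWidth > 8 && NewBitWidth <= 64 &&
         "replacement type must be wider than i8");
  // Converting to int64_t preserves the signed value (sign extension).
  uint64_t Extended = static_cast<uint64_t>(static_cast<int64_t>(Imm));
  // Keep only the low NewBitWidth bits.
  uint64_t Mask = NewBitWidth == 64 ? ~0ULL : ((1ULL << NewBitWidth) - 1);
  return Extended & Mask;
}
```

In the pass itself the equivalent check would presumably compare `NewBitWidth` against the bit width of `Value` right before the `sext` call.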

APInt NewValue = Value.sext(NewBitWidth);
NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
} else {
assert(!Op->getType()->isIntegerTy(8));
NewOperands.push_back(Op);
}
}

Value *NewInst = nullptr;
if (auto *BO = dyn_cast<BinaryOperator>(&I)) {
NewInst =
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);

if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) {
if (OBO->hasNoSignedWrap())
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap();
if (OBO->hasNoUnsignedWrap())
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap();
}
} else if (Cmp) {
NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
NewOperands[1]);
Cmp->replaceAllUsesWith(NewInst);
}

if (NewInst) {
ReplacedValues[&I] = NewInst;
ToRemove.push(&I);
}
} else if (auto *Cast = dyn_cast<CastInst>(&I)) {
if (Cast->getSrcTy()->isIntegerTy(8)) {
ToRemove.push(Cast);
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]);
}
}

return !ToRemove.empty();
}

static bool
downcastI64toI32InsertExtractElements(Instruction &I,
std::stack<Instruction *> &ToRemove,
std::map<Value *, Value *> &) {

if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
Value *Idx = Extract->getIndexOperand();
auto *CI = dyn_cast<ConstantInt>(Idx);
if (CI && CI->getBitWidth() == 64) {
IRBuilder<> Builder(Extract);
int64_t IndexValue = CI->getSExtValue();
auto *Idx32 =
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
Value *NewExtract = Builder.CreateExtractElement(
Extract->getVectorOperand(), Idx32, Extract->getName());

Extract->replaceAllUsesWith(NewExtract);
ToRemove.push(Extract);
}
}

if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
Value *Idx = Insert->getOperand(2);
auto *CI = dyn_cast<ConstantInt>(Idx);
if (CI && CI->getBitWidth() == 64) {
int64_t IndexValue = CI->getSExtValue();
auto *Idx32 =
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
IRBuilder<> Builder(Insert);
Value *Insert32Index = Builder.CreateInsertElement(
Insert->getOperand(0), Insert->getOperand(1), Idx32,
Insert->getName());

Insert->replaceAllUsesWith(Insert32Index);
ToRemove.push(Insert);
}
}

return !ToRemove.empty();
}

class DXILLegalizationPipeline {

public:
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }

bool runLegalizationPipeline(Function &F) {
std::stack<Instruction *> ToRemove;
std::map<Value *, Value *> ReplacedValues;
bool MadeChanges = false;
for (auto &I : instructions(F)) {
for (auto &LegalizationFn : LegalizationPipeline) {
MadeChanges |= LegalizationFn(I, ToRemove, ReplacedValues);
}
}
while (!ToRemove.empty()) {
Instruction *I = ToRemove.top();
I->eraseFromParent();
ToRemove.pop();
Collaborator

You don't really need to be using a stack here right? Isn't it reasonable to just iterate and erase?

Member Author

I do need a stack because most of the i8 cases can't use replaceAllUsesWith, since the instruction types are different. My workaround is to remove instructions starting from the last one we saw, so that there are no remaining uses when we call eraseFromParent.

Collaborator

You don't need a stack for that. You can reverse iterate a SmallVector: for (auto *Inst : reverse(ToRemove))

Member Author

The list size isn't huge, so this probably doesn't matter, but reverse would be applying std::iter_swap to every pair of iterators; that's O(n) + O(n) for erasing vs. just O(n). I'm guessing the SmallVector benefits outweigh this.

Collaborator

llvm::reverse creates a range iterator with reverse iterators, not std::iter_swap, so it is O(n):

https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/ADT/STLExtras.h#L428
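
A standalone sketch of the same idea, using `std::vector` and its reverse iterators as stand-ins for `llvm::SmallVector` and `llvm::reverse`. `eraseOrder` is a hypothetical helper that records the visiting order instead of erasing real instructions.

```cpp
#include <string>
#include <vector>

// Visit the pending instructions last-recorded-first, as the stack did.
// Reverse iterators walk the existing storage backwards; no elements are
// swapped or moved, so one pass over the list is linear in its size.
std::vector<std::string> eraseOrder(const std::vector<std::string> &ToRemove) {
  std::vector<std::string> Order;
  Order.reserve(ToRemove.size());
  for (auto It = ToRemove.rbegin(), End = ToRemove.rend(); It != End; ++It)
    Order.push_back(*It); // the pass would call eraseFromParent() here
  return Order;
}
```

Because users are recorded after the values they use, visiting the list back to front removes each user before its operand, which is the property the stack was providing.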

}

return MadeChanges;
}

private:
std::vector<std::function<bool(Instruction &, std::stack<Instruction *> &,
std::map<Value *, Value *> &)>>
LegalizationPipeline;

void initializeLegalizationPipeline() {
LegalizationPipeline.push_back(fixI8TruncUseChain);
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
}
};

class DXILLegalizeLegacy : public FunctionPass {

public:
bool runOnFunction(Function &F) override;
DXILLegalizeLegacy() : FunctionPass(ID) {}

static char ID; // Pass identification.
};
} // namespace

PreservedAnalyses DXILLegalizePass::run(Function &F,
FunctionAnalysisManager &FAM) {
DXILLegalizationPipeline DXLegalize;
bool MadeChanges = DXLegalize.runLegalizationPipeline(F);
if (!MadeChanges)
return PreservedAnalyses::all();
PreservedAnalyses PA;
return PA;
}

bool DXILLegalizeLegacy::runOnFunction(Function &F) {
DXILLegalizationPipeline DXLegalize;
return DXLegalize.runLegalizationPipeline(F);
}

char DXILLegalizeLegacy::ID = 0;

INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
false)
INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
false)

FunctionPass *llvm::createDXILLegalizeLegacyPass() {
return new DXILLegalizeLegacy();
}
22 changes: 22 additions & 0 deletions llvm/lib/Target/DirectX/DXILLegalizePass.h
@@ -0,0 +1,22 @@
//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===---------------------------------------------------------------------===//

#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H
#define LLVM_TARGET_DIRECTX_LEGALIZE_H

#include "llvm/IR/PassManager.h"

namespace llvm {

class DXILLegalizePass : public PassInfoMixin<DXILLegalizePass> {
public:
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
};
} // namespace llvm

#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H
7 changes: 7 additions & 0 deletions llvm/lib/Target/DirectX/DirectX.h
@@ -47,6 +47,13 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
/// Pass to flatten arrays into a one dimensional DXIL legal form
ModulePass *createDXILFlattenArraysLegacyPass();

/// Initializer DXIL legalizationPass
void initializeDXILLegalizeLegacyPass(PassRegistry &);

/// Pass to Legalize DXIL by remove i8 truncations and i64 insert/extract
/// elements
FunctionPass *createDXILLegalizeLegacyPass();

/// Initializer for DXILOpLowering
void initializeDXILOpLoweringLegacyPass(PassRegistry &);

1 change: 1 addition & 0 deletions llvm/lib/Target/DirectX/DirectXPassRegistry.def
@@ -38,4 +38,5 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
#define FUNCTION_PASS(NAME, CREATE_PASS)
#endif
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
#undef FUNCTION_PASS
3 changes: 3 additions & 0 deletions llvm/lib/Target/DirectX/DirectXTargetMachine.cpp
@@ -15,6 +15,7 @@
#include "DXILDataScalarization.h"
#include "DXILFlattenArrays.h"
#include "DXILIntrinsicExpansion.h"
#include "DXILLegalizePass.h"
#include "DXILOpLowering.h"
#include "DXILPrettyPrinter.h"
#include "DXILResourceAccess.h"
@@ -52,6 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
initializeDXILDataScalarizationLegacyPass(*PR);
initializeDXILFlattenArraysLegacyPass(*PR);
initializeScalarizerLegacyPassPass(*PR);
initializeDXILLegalizeLegacyPass(*PR);
initializeDXILPrepareModulePass(*PR);
initializeEmbedDXILPassPass(*PR);
initializeWriteDXILPassPass(*PR);
@@ -99,6 +101,7 @@ class DirectXPassConfig : public TargetPassConfig {
ScalarizerPassOptions DxilScalarOptions;
DxilScalarOptions.ScalarizeLoadStore = true;
addPass(createScalarizerPass(DxilScalarOptions));
addPass(createDXILLegalizeLegacyPass());
addPass(createDXILTranslateMetadataLegacyPass());
addPass(createDXILOpLoweringLegacyPass());
addPass(createDXILPrepareModulePass());
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/DirectX/ResourceGlobalElimination.ll
@@ -19,8 +19,8 @@
; CHECK-LABEL define void @main()
define void @main() local_unnamed_addr #0 {
entry:
; DXOP: %In_h.i1 = call %dx.types.Handle @dx.op.createHandle
; DXOP: %Out_h.i2 = call %dx.types.Handle @dx.op.createHandle
; DXOP: [[In_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle
; DXOP: [[Out_h_i:%.*]] = call %dx.types.Handle @dx.op.createHandle
%In_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 0, i32 0, i32 1, i32 0, i1 false)
store target("dx.TypedBuffer", <4 x float>, 1, 0, 0) %In_h.i, ptr @In, align 4
%Out_h.i = call target("dx.TypedBuffer", <4 x float>, 1, 0, 0) @llvm.dx.resource.handlefrombinding.tdx.TypedBuffer_v4f32_1_0_0t(i32 4, i32 1, i32 1, i32 0, i1 false)
24 changes: 24 additions & 0 deletions llvm/test/CodeGen/DirectX/legalize-i64-extract-insert-elements.ll
@@ -0,0 +1,24 @@
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s

define noundef <4 x float> @float4_extract(<4 x float> noundef %a) {
entry:
; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0
; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1
; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2
; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3
; CHECK: insertelement <4 x float> poison, float [[ee0]], i32 0
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee1]], i32 1
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee2]], i32 2
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee3]], i32 3

%a.i0 = extractelement <4 x float> %a, i64 0
%a.i1 = extractelement <4 x float> %a, i64 1
%a.i2 = extractelement <4 x float> %a, i64 2
%a.i3 = extractelement <4 x float> %a, i64 3

%.upto0 = insertelement <4 x float> poison, float %a.i0, i64 0
%.upto1 = insertelement <4 x float> %.upto0, float %a.i1, i64 1
%.upto2 = insertelement <4 x float> %.upto1, float %a.i2, i64 2
%0 = insertelement <4 x float> %.upto2, float %a.i3, i64 3
ret <4 x float> %0
}