-
Notifications
You must be signed in to change notification settings - Fork 14.8k
[DirectX] Start the creation of a DXIL Instruction legalizer #131221
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
162e8b4
679ac35
7d93231
0a561f5
64937d2
cbbdd62
f8be37f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,209 @@ | ||
//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL ------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===---------------------------------------------------------------------===// | ||
//===---------------------------------------------------------------------===// | ||
/// | ||
/// \file This file contains a pass to remove i8 truncations and i64 extract | ||
/// and insert elements. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this the only set of legalizations you expect this pass to handle? If not, this will probably get out of date quickly. |
||
/// | ||
//===----------------------------------------------------------------------===// | ||
#include "DXILLegalizePass.h" | ||
#include "DirectX.h" | ||
#include "llvm/IR/Function.h" | ||
#include "llvm/IR/IRBuilder.h" | ||
#include "llvm/IR/InstIterator.h" | ||
#include "llvm/IR/Instruction.h" | ||
#include "llvm/Pass.h" | ||
#include "llvm/Transforms/Utils/BasicBlockUtils.h" | ||
#include <functional> | ||
#include <map> | ||
#include <stack> | ||
#include <vector> | ||
|
||
#define DEBUG_TYPE "dxil-legalize" | ||
|
||
using namespace llvm; | ||
namespace { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nit: don't include
see: https://llvm.org/docs/CodingStandards.html#restrict-visibility |
||
|
||
static void fixI8TruncUseChain(Instruction &I, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I find the nesting in this function hard to follow, especially since the longest case is in the middle. Since all of the instruction types are mutually exclusive, do you think it would read better to structure it like: if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
// ...
return;
}
if (auto *Cast = dyn_cast<CastInst>(&I)) {
// ...
return;
}
auto *Cmp = dyn_cast<CmpInst>(&I);
if (!I.getType()->isIntegerTy(8) &&
!(Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8)))
// Nothing to do...
return;
// ... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think i have something cleaner let me know in the new pr. |
||
std::stack<Instruction *> &ToRemove, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. llvm::SmallVector and llvm::SmallVectorImpl should be preferred over std::stack. |
||
std::map<Value *, Value *> &ReplacedValues) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. llvm::DenseMap should almost always be used instead of std::map. |
||
|
||
auto *Cmp = dyn_cast<CmpInst>(&I); | ||
|
||
if (auto *Trunc = dyn_cast<TruncInst>(&I)) { | ||
if (Trunc->getDestTy()->isIntegerTy(8)) { | ||
ReplacedValues[Trunc] = Trunc->getOperand(0); | ||
ToRemove.push(Trunc); | ||
} | ||
} else if (I.getType()->isIntegerTy(8) || | ||
(Cmp && Cmp->getOperand(0)->getType()->isIntegerTy(8))) { | ||
IRBuilder<> Builder(&I); | ||
|
||
std::vector<Value *> NewOperands; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please use llvm::SmallVector to avoid dynamic allocations. |
||
Type *InstrType = IntegerType::get(I.getContext(), 32); | ||
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { | ||
farzonl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
Value *Op = I.getOperand(OpIdx); | ||
if (ReplacedValues.count(Op)) | ||
InstrType = ReplacedValues[Op]->getType(); | ||
Comment on lines
+51
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is it possible to get conflicting InstrTypes here? Does this deserve an assert? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I want to say no, but not sure what you mean. I'll just explain what I was thinking here and maybe that will answer your question. So I'm setting the default instruction type to i32 that happens on line 48. That means if we ever have an instruction and all its operands are imm 8 bits we would sext them all to 32. What the loop on line 49 does is to look at the operands and see if we have created any replacement types for them. these typically end up being the values before a trunc operation. Now we are able to set the instruction type. This is really only important for BinaryOperator instructions. You can see the above explanations exercised in the |
||
} | ||
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) { | ||
Value *Op = I.getOperand(OpIdx); | ||
if (ReplacedValues.count(Op)) | ||
NewOperands.push_back(ReplacedValues[Op]); | ||
else if (auto *Imm = dyn_cast<ConstantInt>(Op)) { | ||
APInt Value = Imm->getValue(); | ||
unsigned NewBitWidth = InstrType->getIntegerBitWidth(); | ||
// Note: options here are sext or sextOrTrunc. | ||
// Since i8 isn't supported, we assume new values | ||
// will always have a higher bitness. | ||
Comment on lines
+62
to
+63
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An assert to document/enforce the assumption wouldn't hurt |
||
APInt NewValue = Value.sext(NewBitWidth); | ||
NewOperands.push_back(ConstantInt::get(InstrType, NewValue)); | ||
} else { | ||
assert(!Op->getType()->isIntegerTy(8)); | ||
NewOperands.push_back(Op); | ||
} | ||
} | ||
|
||
Value *NewInst = nullptr; | ||
if (auto *BO = dyn_cast<BinaryOperator>(&I)) { | ||
NewInst = | ||
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]); | ||
|
||
if (auto *OBO = dyn_cast<OverflowingBinaryOperator>(&I)) { | ||
if (OBO->hasNoSignedWrap()) | ||
cast<BinaryOperator>(NewInst)->setHasNoSignedWrap(); | ||
if (OBO->hasNoUnsignedWrap()) | ||
cast<BinaryOperator>(NewInst)->setHasNoUnsignedWrap(); | ||
} | ||
} else if (Cmp) { | ||
NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0], | ||
NewOperands[1]); | ||
Cmp->replaceAllUsesWith(NewInst); | ||
} | ||
|
||
if (NewInst) { | ||
ReplacedValues[&I] = NewInst; | ||
ToRemove.push(&I); | ||
} | ||
} else if (auto *Cast = dyn_cast<CastInst>(&I)) { | ||
if (Cast->getSrcTy()->isIntegerTy(8)) { | ||
ToRemove.push(Cast); | ||
Cast->replaceAllUsesWith(ReplacedValues[Cast->getOperand(0)]); | ||
} | ||
} | ||
} | ||
|
||
static void | ||
downcastI64toI32InsertExtractElements(Instruction &I, | ||
std::stack<Instruction *> &ToRemove, | ||
std::map<Value *, Value *> &) { | ||
|
||
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) { | ||
Value *Idx = Extract->getIndexOperand(); | ||
auto *CI = dyn_cast<ConstantInt>(Idx); | ||
if (CI && CI->getBitWidth() == 64) { | ||
IRBuilder<> Builder(Extract); | ||
int64_t IndexValue = CI->getSExtValue(); | ||
auto *Idx32 = | ||
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); | ||
Value *NewExtract = Builder.CreateExtractElement( | ||
Extract->getVectorOperand(), Idx32, Extract->getName()); | ||
|
||
Extract->replaceAllUsesWith(NewExtract); | ||
ToRemove.push(Extract); | ||
} | ||
} | ||
|
||
if (auto *Insert = dyn_cast<InsertElementInst>(&I)) { | ||
Value *Idx = Insert->getOperand(2); | ||
auto *CI = dyn_cast<ConstantInt>(Idx); | ||
if (CI && CI->getBitWidth() == 64) { | ||
int64_t IndexValue = CI->getSExtValue(); | ||
auto *Idx32 = | ||
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue); | ||
IRBuilder<> Builder(Insert); | ||
Value *Insert32Index = Builder.CreateInsertElement( | ||
Insert->getOperand(0), Insert->getOperand(1), Idx32, | ||
Insert->getName()); | ||
|
||
Insert->replaceAllUsesWith(Insert32Index); | ||
ToRemove.push(Insert); | ||
} | ||
} | ||
} | ||
|
||
class DXILLegalizationPipeline { | ||
|
||
public: | ||
DXILLegalizationPipeline() { initializeLegalizationPipeline(); } | ||
|
||
bool runLegalizationPipeline(Function &F) { | ||
std::stack<Instruction *> ToRemove; | ||
std::map<Value *, Value *> ReplacedValues; | ||
for (auto &I : instructions(F)) { | ||
for (auto &LegalizationFn : LegalizationPipeline) { | ||
LegalizationFn(I, ToRemove, ReplacedValues); | ||
} | ||
} | ||
Comment on lines
+148
to
+152
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
bool MadeChanges = !ToRemove.empty(); | ||
|
||
while (!ToRemove.empty()) { | ||
Instruction *I = ToRemove.top(); | ||
I->eraseFromParent(); | ||
ToRemove.pop(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't really need to be using a stack here right? Isn't it reasonable to just iterate and erase? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do need a stack because most the i8 cases can’t use replaceAllUsesWith because the instruction types are different. My work around is to remove starting from the last instruction we saw so that there are no uses when we call eraseFromParent There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You don't need a stack for that. You can reverse iterate a SmallVector: There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So the list size isn't huge so this probably doesn't matter, but reverse would be applying std::iter_swap to every pair of iterators thats O(n) + O(n) for erasing vs just O(n). I'm guessing the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/ADT/STLExtras.h#L428 |
||
} | ||
|
||
return MadeChanges; | ||
farzonl marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
private: | ||
std::vector<std::function<void(Instruction &, std::stack<Instruction *> &, | ||
std::map<Value *, Value *> &)>> | ||
LegalizationPipeline; | ||
|
||
void initializeLegalizationPipeline() { | ||
LegalizationPipeline.push_back(fixI8TruncUseChain); | ||
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements); | ||
} | ||
}; | ||
|
||
class DXILLegalizeLegacy : public FunctionPass { | ||
|
||
public: | ||
bool runOnFunction(Function &F) override; | ||
DXILLegalizeLegacy() : FunctionPass(ID) {} | ||
|
||
static char ID; // Pass identification. | ||
}; | ||
} // namespace | ||
|
||
PreservedAnalyses DXILLegalizePass::run(Function &F, | ||
FunctionAnalysisManager &FAM) { | ||
DXILLegalizationPipeline DXLegalize; | ||
bool MadeChanges = DXLegalize.runLegalizationPipeline(F); | ||
if (!MadeChanges) | ||
return PreservedAnalyses::all(); | ||
PreservedAnalyses PA; | ||
return PA; | ||
} | ||
|
||
bool DXILLegalizeLegacy::runOnFunction(Function &F) { | ||
DXILLegalizationPipeline DXLegalize; | ||
return DXLegalize.runLegalizationPipeline(F); | ||
} | ||
|
||
char DXILLegalizeLegacy::ID = 0; | ||
|
||
INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, | ||
false) | ||
INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false, | ||
false) | ||
|
||
FunctionPass *llvm::createDXILLegalizeLegacyPass() { | ||
return new DXILLegalizeLegacy(); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,22 @@ | ||
//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL --------------------===// | ||
// | ||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. | ||
// See https://llvm.org/LICENSE.txt for license information. | ||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
// | ||
//===---------------------------------------------------------------------===// | ||
|
||
#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H | ||
#define LLVM_TARGET_DIRECTX_LEGALIZE_H | ||
|
||
#include "llvm/IR/PassManager.h" | ||
|
||
namespace llvm { | ||
|
||
class DXILLegalizePass : public PassInfoMixin<DXILLegalizePass> { | ||
public: | ||
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM); | ||
}; | ||
} // namespace llvm | ||
|
||
#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
; RUN: opt -S -passes='dxil-legalize' -mtriple=dxil-pc-shadermodel6.3-library %s | FileCheck %s | ||
|
||
define noundef <4 x float> @float4_extract(<4 x float> noundef %a) { | ||
entry: | ||
; CHECK: [[ee0:%.*]] = extractelement <4 x float> %a, i32 0 | ||
; CHECK: [[ee1:%.*]] = extractelement <4 x float> %a, i32 1 | ||
; CHECK: [[ee2:%.*]] = extractelement <4 x float> %a, i32 2 | ||
; CHECK: [[ee3:%.*]] = extractelement <4 x float> %a, i32 3 | ||
; CHECK: insertelement <4 x float> poison, float [[ee0]], i32 0 | ||
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee1]], i32 1 | ||
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee2]], i32 2 | ||
; CHECK: insertelement <4 x float> %{{.*}}, float [[ee3]], i32 3 | ||
|
||
%a.i0 = extractelement <4 x float> %a, i64 0 | ||
%a.i1 = extractelement <4 x float> %a, i64 1 | ||
%a.i2 = extractelement <4 x float> %a, i64 2 | ||
%a.i3 = extractelement <4 x float> %a, i64 3 | ||
|
||
%.upto0 = insertelement <4 x float> poison, float %a.i0, i64 0 | ||
%.upto1 = insertelement <4 x float> %.upto0, float %a.i1, i64 1 | ||
%.upto2 = insertelement <4 x float> %.upto1, float %a.i2, i64 2 | ||
%0 = insertelement <4 x float> %.upto2, float %a.i3, i64 3 | ||
ret <4 x float> %0 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Doubled heading line is unncessary.