Skip to content

Commit 9fd631a

Browse files
committed
modify the i8 legalization pass to be a more generic legalization so we can reduce i64 insert/extracts to i32
1 parent 34939c8 commit 9fd631a

12 files changed

+289
-171
lines changed

llvm/lib/Target/DirectX/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ add_llvm_target(DirectXCodeGen
3232
DXILShaderFlags.cpp
3333
DXILTranslateMetadata.cpp
3434
DXILRootSignature.cpp
35-
LegalizeI8Pass.cpp
35+
DXILLegalizePass.cpp
3636

3737
LINK_COMPONENTS
3838
Analysis
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
//===- DXILLegalizePass.cpp - Legalizes llvm IR for DXIL-*- C++----------*-===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===---------------------------------------------------------------------===//
8+
//===---------------------------------------------------------------------===//
9+
///
10+
/// \file This file contains a pass to remove i8 truncations and i64 extract
11+
/// and insert elements.
12+
///
13+
//===----------------------------------------------------------------------===//
14+
#include "DXILLegalizePass.h"
15+
#include "DirectX.h"
16+
#include "llvm/IR/Function.h"
17+
#include "llvm/IR/IRBuilder.h"
18+
#include "llvm/IR/InstIterator.h"
19+
#include "llvm/IR/Instruction.h"
20+
#include "llvm/Pass.h"
21+
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
22+
#include <functional>
23+
#include <map>
24+
#include <stack>
25+
#include <vector>
26+
27+
#define DEBUG_TYPE "dxil-legalize"
28+
29+
using namespace llvm;
30+
namespace {
31+
32+
static bool fixI8TruncUseChain(Instruction &I,
33+
std::stack<Instruction *> &ToRemove,
34+
std::map<Value *, Value *> &ReplacedValues) {
35+
36+
if (auto *Trunc = dyn_cast<TruncInst>(&I)) {
37+
if (Trunc->getDestTy()->isIntegerTy(8)) {
38+
ReplacedValues[Trunc] = Trunc->getOperand(0);
39+
ToRemove.push(Trunc);
40+
}
41+
} else if (I.getType()->isIntegerTy(8)) {
42+
IRBuilder<> Builder(&I);
43+
44+
std::vector<Value *> NewOperands;
45+
Type *InstrType = nullptr;
46+
for (unsigned OpIdx = 0; OpIdx < I.getNumOperands(); ++OpIdx) {
47+
Value *Op = I.getOperand(OpIdx);
48+
if (ReplacedValues.count(Op)) {
49+
InstrType = ReplacedValues[Op]->getType();
50+
NewOperands.push_back(ReplacedValues[Op]);
51+
} else if (auto *Imm = dyn_cast<ConstantInt>(Op)) {
52+
APInt Value = Imm->getValue();
53+
unsigned NewBitWidth = InstrType->getIntegerBitWidth();
54+
// Note: options here are sext or sextOrTrunc.
55+
// Since i8 isn't suppport we assume new values
56+
// will always have a higher bitness.
57+
APInt NewValue = Value.sext(NewBitWidth);
58+
NewOperands.push_back(ConstantInt::get(InstrType, NewValue));
59+
} else {
60+
assert(!Op->getType()->isIntegerTy(8));
61+
NewOperands.push_back(Op);
62+
}
63+
}
64+
65+
Value *NewInst = nullptr;
66+
if (auto *BO = dyn_cast<BinaryOperator>(&I))
67+
NewInst =
68+
Builder.CreateBinOp(BO->getOpcode(), NewOperands[0], NewOperands[1]);
69+
else if (auto *Cmp = dyn_cast<CmpInst>(&I))
70+
NewInst = Builder.CreateCmp(Cmp->getPredicate(), NewOperands[0],
71+
NewOperands[1]);
72+
else if (auto *Cast = dyn_cast<CastInst>(&I))
73+
NewInst = Builder.CreateCast(Cast->getOpcode(), NewOperands[0],
74+
Cast->getDestTy());
75+
else if (auto *UnaryOp = dyn_cast<UnaryOperator>(&I))
76+
NewInst = Builder.CreateUnOp(UnaryOp->getOpcode(), NewOperands[0]);
77+
78+
if (NewInst) {
79+
ReplacedValues[&I] = NewInst;
80+
ToRemove.push(&I);
81+
}
82+
} else if (auto *Sext = dyn_cast<SExtInst>(&I)) {
83+
if (Sext->getSrcTy()->isIntegerTy(8)) {
84+
ToRemove.push(Sext);
85+
Sext->replaceAllUsesWith(ReplacedValues[Sext->getOperand(0)]);
86+
}
87+
}
88+
89+
return !ToRemove.empty();
90+
}
91+
92+
static bool downcastI64toI32InsertExtractElements(
93+
Instruction &I, std::stack<Instruction *> &ToRemove, std::map<Value *, Value *> &) {
94+
95+
if (auto *Extract = dyn_cast<ExtractElementInst>(&I)) {
96+
Value *Idx = Extract->getIndexOperand();
97+
auto *CI = dyn_cast<ConstantInt>(Idx);
98+
if (CI && CI->getBitWidth() == 64) {
99+
IRBuilder<> Builder(Extract);
100+
int64_t IndexValue = CI->getSExtValue();
101+
auto *Idx32 =
102+
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
103+
Value *NewExtract =
104+
Builder.CreateExtractElement(Extract->getVectorOperand(), Idx32,Extract->getName());
105+
106+
Extract->replaceAllUsesWith(NewExtract);
107+
ToRemove.push(Extract);
108+
}
109+
}
110+
111+
if (auto *Insert = dyn_cast<InsertElementInst>(&I)) {
112+
Value *Idx = Insert->getOperand(2);
113+
auto *CI = dyn_cast<ConstantInt>(Idx);
114+
if (CI && CI->getBitWidth() == 64) {
115+
int64_t IndexValue = CI->getSExtValue();
116+
auto *Idx32 =
117+
ConstantInt::get(Type::getInt32Ty(I.getContext()), IndexValue);
118+
IRBuilder<> Builder(Insert);
119+
Value *Insert32Index = Builder.CreateInsertElement(
120+
Insert->getOperand(0), Insert->getOperand(1), Idx32, Insert->getName());
121+
122+
Insert->replaceAllUsesWith(Insert32Index);
123+
ToRemove.push(Insert);
124+
}
125+
}
126+
127+
return !ToRemove.empty();
128+
}
129+
130+
class DXILLegalizationPipeline {
131+
132+
public:
133+
DXILLegalizationPipeline() { initializeLegalizationPipeline(); }
134+
135+
bool runLegalizationPipeline(Function &F) {
136+
std::stack<Instruction *> ToRemove;
137+
std::map<Value *, Value *> ReplacedValues;
138+
bool MadeChanges = false;
139+
for (auto &I : instructions(F)) {
140+
for (auto &LegalizationFn : LegalizationPipeline) {
141+
MadeChanges = LegalizationFn(I, ToRemove, ReplacedValues);
142+
}
143+
}
144+
while (!ToRemove.empty()) {
145+
Instruction *I = ToRemove.top();
146+
I->eraseFromParent();
147+
ToRemove.pop();
148+
}
149+
150+
return MadeChanges;
151+
}
152+
153+
private:
154+
std::vector<std::function<bool(Instruction &, std::stack<Instruction *> &,
155+
std::map<Value *, Value *> &)>>
156+
LegalizationPipeline;
157+
158+
void initializeLegalizationPipeline() {
159+
LegalizationPipeline.push_back(fixI8TruncUseChain);
160+
LegalizationPipeline.push_back(downcastI64toI32InsertExtractElements);
161+
}
162+
};
163+
164+
class DXILLegalizeLegacy : public FunctionPass {
165+
166+
public:
167+
bool runOnFunction(Function &F) override;
168+
DXILLegalizeLegacy() : FunctionPass(ID) {}
169+
170+
static char ID; // Pass identification.
171+
};
172+
} // namespace
173+
174+
PreservedAnalyses DXILLegalizePass::run(Function &F,
175+
FunctionAnalysisManager &FAM) {
176+
DXILLegalizationPipeline DXLegalize;
177+
bool MadeChanges = DXLegalize.runLegalizationPipeline(F);
178+
if (!MadeChanges)
179+
return PreservedAnalyses::all();
180+
PreservedAnalyses PA;
181+
return PA;
182+
}
183+
184+
bool DXILLegalizeLegacy::runOnFunction(Function &F) {
185+
DXILLegalizationPipeline DXLegalize;
186+
return DXLegalize.runLegalizationPipeline(F);
187+
}
188+
189+
char DXILLegalizeLegacy::ID = 0;
190+
191+
INITIALIZE_PASS_BEGIN(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
192+
false)
193+
INITIALIZE_PASS_END(DXILLegalizeLegacy, DEBUG_TYPE, "DXIL Legalizer", false,
194+
false)
195+
196+
FunctionPass *llvm::createDXILLegalizeLegacyPass() {
197+
return new DXILLegalizeLegacy();
198+
}
Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,22 @@
1-
//===- LegalizeI8Pass.h - A pass that reverts i8 conversions-*- C++ -----*-===//
1+
//===- DXILLegalizePass.h - Legalizes llvm IR for DXIL-*- C++------------*-===//
22
//
33
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
44
// See https://llvm.org/LICENSE.txt for license information.
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===---------------------------------------------------------------------===//
88

9-
#ifndef LLVM_TARGET_DIRECTX_LEGALIZEI8_H
10-
#define LLVM_TARGET_DIRECTX_LEGALIZEI8_H
9+
#ifndef LLVM_TARGET_DIRECTX_LEGALIZE_H
10+
#define LLVM_TARGET_DIRECTX_LEGALIZE_H
1111

1212
#include "llvm/IR/PassManager.h"
1313

1414
namespace llvm {
1515

16-
/// A pass that transforms multidimensional arrays into one-dimensional arrays.
17-
class LegalizeI8Pass : public PassInfoMixin<LegalizeI8Pass> {
16+
class DXILLegalizePass : public PassInfoMixin<DXILLegalizePass> {
1817
public:
1918
PreservedAnalyses run(Function &F, FunctionAnalysisManager &FAM);
2019
};
2120
} // namespace llvm
2221

23-
#endif // LLVM_TARGET_DIRECTX_LEGALIZEI8_H
22+
#endif // LLVM_TARGET_DIRECTX_LEGALIZE_H

llvm/lib/Target/DirectX/DirectX.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,12 @@ void initializeDXILFlattenArraysLegacyPass(PassRegistry &);
4747
/// Pass to flatten arrays into a one dimensional DXIL legal form
4848
ModulePass *createDXILFlattenArraysLegacyPass();
4949

50-
/// Initializer I8 legalizationPass
51-
void initializeLegalizeI8LegacyPass(PassRegistry &);
50+
/// Initializer DXIL legalizationPass
51+
void initializeDXILLegalizeLegacyPass(PassRegistry &);
5252

53-
/// Pass to remove i8 truncations
54-
FunctionPass *createLegalizeI8LegacyPass();
53+
/// Pass to Legalize DXIL by remove i8 truncations and i64 insert/extract
54+
/// elements
55+
FunctionPass *createDXILLegalizeLegacyPass();
5556

5657
/// Initializer for DXILOpLowering
5758
void initializeDXILOpLoweringLegacyPass(PassRegistry &);

llvm/lib/Target/DirectX/DirectXPassRegistry.def

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,5 @@ MODULE_PASS("print<dxil-root-signature>", dxil::RootSignatureAnalysisPrinter(dbg
3838
#define FUNCTION_PASS(NAME, CREATE_PASS)
3939
#endif
4040
FUNCTION_PASS("dxil-resource-access", DXILResourceAccess())
41-
FUNCTION_PASS("dxil-legalize-i8", LegalizeI8Pass())
41+
FUNCTION_PASS("dxil-legalize", DXILLegalizePass())
4242
#undef FUNCTION_PASS

llvm/lib/Target/DirectX/DirectXTargetMachine.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "DXILDataScalarization.h"
1616
#include "DXILFlattenArrays.h"
1717
#include "DXILIntrinsicExpansion.h"
18+
#include "DXILLegalizePass.h"
1819
#include "DXILOpLowering.h"
1920
#include "DXILPrettyPrinter.h"
2021
#include "DXILResourceAccess.h"
@@ -25,7 +26,6 @@
2526
#include "DirectX.h"
2627
#include "DirectXSubtarget.h"
2728
#include "DirectXTargetTransformInfo.h"
28-
#include "LegalizeI8Pass.h"
2929
#include "TargetInfo/DirectXTargetInfo.h"
3030
#include "llvm/CodeGen/MachineModuleInfo.h"
3131
#include "llvm/CodeGen/Passes.h"
@@ -53,7 +53,7 @@ extern "C" LLVM_EXTERNAL_VISIBILITY void LLVMInitializeDirectXTarget() {
5353
initializeDXILDataScalarizationLegacyPass(*PR);
5454
initializeDXILFlattenArraysLegacyPass(*PR);
5555
initializeScalarizerLegacyPassPass(*PR);
56-
initializeLegalizeI8LegacyPass(*PR);
56+
initializeDXILLegalizeLegacyPass(*PR);
5757
initializeDXILPrepareModulePass(*PR);
5858
initializeEmbedDXILPassPass(*PR);
5959
initializeWriteDXILPassPass(*PR);
@@ -101,8 +101,8 @@ class DirectXPassConfig : public TargetPassConfig {
101101
ScalarizerPassOptions DxilScalarOptions;
102102
DxilScalarOptions.ScalarizeLoadStore = true;
103103
addPass(createScalarizerPass(DxilScalarOptions));
104+
addPass(createDXILLegalizeLegacyPass());
104105
addPass(createDXILTranslateMetadataLegacyPass());
105-
addPass(createLegalizeI8LegacyPass());
106106
addPass(createDXILOpLoweringLegacyPass());
107107
addPass(createDXILPrepareModulePass());
108108
}

0 commit comments

Comments
 (0)