Skip to content

Commit 7c4d322

Browse files
committed
Narrow blanket SPIR-V legalization work in optimizer recipes
1 parent 8a13595 commit 7c4d322

File tree

5 files changed

+120
-95
lines changed

5 files changed

+120
-95
lines changed

include/spirv-tools/optimizer.hpp

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,13 @@ class Pass;
3232
struct DescriptorSetAndBinding;
3333
} // namespace opt
3434

35+
enum class SSARewriteMode {
36+
None,
37+
All,
38+
OpaqueOnly,
39+
SpecialTypes,
40+
};
41+
3542
// C++ interface for SPIR-V optimization functionalities. It wraps the context
3643
// (including target environment and the corresponding SPIR-V grammar) and
3744
// provides methods for registering optimization passes and optimizing.
@@ -125,6 +132,9 @@ class SPIRV_TOOLS_EXPORT Optimizer {
125132
// interface are considered live and are not eliminated.
126133
Optimizer& RegisterLegalizationPasses();
127134
Optimizer& RegisterLegalizationPasses(bool preserve_interface);
135+
Optimizer& RegisterLegalizationPasses(bool preserve_interface,
136+
bool include_loop_unroll,
137+
SSARewriteMode ssa_rewrite_mode);
128138

129139
// Register passes specified in the list of |flags|. Each flag must be a
130140
// string of a form accepted by Optimizer::FlagHasValidForm().
@@ -645,11 +655,6 @@ Optimizer::PassToken CreateLoopPeelingPass();
645655
// Works best after LICM and local multi store elimination pass.
646656
Optimizer::PassToken CreateLoopUnswitchPass();
647657

648-
// Creates a pass to legalize multidimensional arrays for Vulkan.
649-
// This pass will replace multidimensional arrays of resources with a single
650-
// dimensional array. Combine-access-chains should be run before this pass.
651-
Optimizer::PassToken CreateLegalizeMultidimArrayPass();
652-
653658
// Create global value numbering pass.
654659
// This pass will look for instructions where the same value is computed on all
655660
// paths leading to the instruction. Those instructions are deleted.
@@ -709,7 +714,8 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor = 0);
709714
// operations on SSA IDs. This allows SSA optimizers to act on these variables.
710715
// Only variables that are local to the function and of supported types are
711716
// processed (see IsSSATargetVar for details).
712-
Optimizer::PassToken CreateSSARewritePass();
717+
Optimizer::PassToken CreateSSARewritePass(
718+
SSARewriteMode mode = SSARewriteMode::All);
713719

714720
// Create pass to convert relaxed precision instructions to half precision.
715721
// This pass converts as many relaxed float32 arithmetic operations to half as

source/opt/mem_pass.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,27 @@ bool MemPass::IsBaseTargetType(const Instruction* typeInst) const {
5353
}
5454

5555
bool MemPass::IsTargetType(const Instruction* typeInst) const {
56-
if (IsBaseTargetType(typeInst)) return true;
56+
switch (ssa_rewrite_mode_) {
57+
case SSARewriteMode::None:
58+
return false;
59+
case SSARewriteMode::OpaqueOnly:
60+
if (typeInst->IsOpaqueType()) return true;
61+
break;
62+
case SSARewriteMode::SpecialTypes:
63+
if (typeInst->IsOpaqueType()) return true;
64+
switch (typeInst->opcode()) {
65+
case spv::Op::OpTypePointer:
66+
case spv::Op::OpTypeCooperativeMatrixNV:
67+
case spv::Op::OpTypeCooperativeMatrixKHR:
68+
return true;
69+
default:
70+
break;
71+
}
72+
break;
73+
case SSARewriteMode::All:
74+
if (IsBaseTargetType(typeInst)) return true;
75+
break;
76+
}
5777
if (typeInst->opcode() == spv::Op::OpTypeArray) {
5878
if (!IsTargetType(
5979
get_def_use_mgr()->GetDef(typeInst->GetSingleWordOperand(1)))) {
@@ -72,8 +92,7 @@ bool MemPass::IsTargetType(const Instruction* typeInst) const {
7292

7393
bool MemPass::IsNonPtrAccessChain(const spv::Op opcode) const {
7494
return opcode == spv::Op::OpAccessChain ||
75-
opcode == spv::Op::OpInBoundsAccessChain ||
76-
opcode == spv::Op::OpUntypedAccessChainKHR;
95+
opcode == spv::Op::OpInBoundsAccessChain;
7796
}
7897

7998
bool MemPass::IsPtr(uint32_t ptrId) {
@@ -89,14 +108,11 @@ bool MemPass::IsPtr(uint32_t ptrId) {
89108
ptrInst = get_def_use_mgr()->GetDef(varId);
90109
}
91110
const spv::Op op = ptrInst->opcode();
92-
if (op == spv::Op::OpVariable || op == spv::Op::OpUntypedVariableKHR ||
93-
IsNonPtrAccessChain(op))
94-
return true;
111+
if (op == spv::Op::OpVariable || IsNonPtrAccessChain(op)) return true;
95112
const uint32_t varTypeId = ptrInst->type_id();
96113
if (varTypeId == 0) return false;
97114
const Instruction* varTypeInst = get_def_use_mgr()->GetDef(varTypeId);
98-
return varTypeInst->opcode() == spv::Op::OpTypePointer ||
99-
varTypeInst->opcode() == spv::Op::OpTypeUntypedPointerKHR;
115+
return varTypeInst->opcode() == spv::Op::OpTypePointer;
100116
}
101117

102118
Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
@@ -106,13 +122,11 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
106122

107123
switch (ptrInst->opcode()) {
108124
case spv::Op::OpVariable:
109-
case spv::Op::OpUntypedVariableKHR:
110125
case spv::Op::OpFunctionParameter:
111126
varInst = ptrInst;
112127
break;
113128
case spv::Op::OpAccessChain:
114129
case spv::Op::OpInBoundsAccessChain:
115-
case spv::Op::OpUntypedAccessChainKHR:
116130
case spv::Op::OpPtrAccessChain:
117131
case spv::Op::OpInBoundsPtrAccessChain:
118132
case spv::Op::OpImageTexelPointer:
@@ -125,8 +139,7 @@ Instruction* MemPass::GetPtr(uint32_t ptrId, uint32_t* varId) {
125139
break;
126140
}
127141

128-
if (varInst->opcode() == spv::Op::OpVariable ||
129-
varInst->opcode() == spv::Op::OpUntypedVariableKHR) {
142+
if (varInst->opcode() == spv::Op::OpVariable) {
130143
*varId = varInst->result_id();
131144
} else {
132145
*varId = 0;
@@ -241,7 +254,8 @@ void MemPass::DCEInst(Instruction* inst,
241254
}
242255
}
243256

244-
MemPass::MemPass() {}
257+
MemPass::MemPass(SSARewriteMode ssa_rewrite_mode)
258+
: ssa_rewrite_mode_(ssa_rewrite_mode) {}
245259

246260
bool MemPass::HasOnlySupportedRefs(uint32_t varId) {
247261
return get_def_use_mgr()->WhileEachUser(varId, [this](Instruction* user) {

source/opt/mem_pass.h

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
#include <unordered_set>
2626
#include <utility>
2727

28+
#include "spirv-tools/optimizer.hpp"
2829
#include "source/opt/basic_block.h"
2930
#include "source/opt/def_use_manager.h"
3031
#include "source/opt/dominator_analysis.h"
@@ -68,7 +69,7 @@ class MemPass : public Pass {
6869
void CollectTargetVars(Function* func);
6970

7071
protected:
71-
MemPass();
72+
explicit MemPass(SSARewriteMode ssa_rewrite_mode = SSARewriteMode::All);
7273

7374
// Returns true if |typeInst| is a scalar type
7475
// or a vector or matrix
@@ -133,7 +134,9 @@ class MemPass : public Pass {
133134
// Cache of verified non-target vars
134135
std::unordered_set<uint32_t> seen_non_target_vars_;
135136

136-
private:
137+
private:
138+
SSARewriteMode ssa_rewrite_mode_ = SSARewriteMode::All;
139+
137140
// Return true if all uses of |varId| are only through supported reference
138141
// operations, i.e. loads and stores. Also cache in supported_ref_vars_.
139142
// TODO(dnovillo): This function is replicated in other passes and it's

source/opt/optimizer.cpp

Lines changed: 74 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,10 @@ Optimizer& Optimizer::RegisterPass(PassToken&& p) {
120120
// The legalization problem is essentially a very general copy propagation
121121
// problem. The optimizations we use are all used to either do copy propagation
122122
// or enable more copy propagation.
123-
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
124-
return
123+
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface,
124+
bool include_loop_unroll,
125+
SSARewriteMode ssa_rewrite_mode) {
126+
auto& optimizer =
125127
// Wrap OpKill instructions so all other code can be inlined.
126128
RegisterPass(CreateWrapOpKillPass())
127129
// Remove unreachable block so that merge return works.
@@ -130,87 +132,93 @@ Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
130132
.RegisterPass(CreateMergeReturnPass())
131133
// Make sure uses and definitions are in the same function.
132134
.RegisterPass(CreateInlineExhaustivePass())
133-
// Make private variable function scope
134-
.RegisterPass(CreateEliminateDeadFunctionsPass())
135-
.RegisterPass(CreatePrivateToLocalPass())
136-
// Fix up the storage classes that DXC may have purposely generated
137-
// incorrectly. All functions are inlined, and a lot of dead code has
138-
// been removed.
139-
.RegisterPass(CreateFixStorageClassPass())
140-
// Propagate the value stored to the loads in very simple cases.
141-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
142-
.RegisterPass(CreateLocalSingleStoreElimPass())
143-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
144-
// Split up aggregates so they are easier to deal with.
145-
.RegisterPass(CreateScalarReplacementPass(0))
146-
// Remove loads and stores so everything is in intermediate values.
147-
// Takes care of copy propagation of non-members.
148-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
149-
.RegisterPass(CreateLocalSingleStoreElimPass())
150-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
151-
.RegisterPass(CreateLocalMultiStoreElimPass())
152-
.RegisterPass(CreateCombineAccessChainsPass())
153-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
154-
.RegisterPass(CreateLegalizeMultidimArrayPass())
155-
// Propagate constants to get as many constant conditions on branches
156-
// as possible.
157-
.RegisterPass(CreateCCPPass())
158-
.RegisterPass(CreateLoopUnrollPass(true))
159-
.RegisterPass(CreateDeadBranchElimPass())
160-
// Copy propagate members. Cleans up code sequences generated by
161-
// scalar replacement. Also important for removing OpPhi nodes.
162-
.RegisterPass(CreateSimplificationPass())
163-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
164-
.RegisterPass(CreateCopyPropagateArraysPass())
165-
// May need loop unrolling here see
166-
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
167-
// Get rid of unused code that contains traces of illegal code
168-
// or unused references to unbound external objects
169-
.RegisterPass(CreateVectorDCEPass())
170-
.RegisterPass(CreateDeadInsertElimPass())
171-
.RegisterPass(CreateReduceLoadSizePass())
172-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
173-
.RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
174-
.RegisterPass(CreateInterpolateFixupPass())
175-
.RegisterPass(CreateInvocationInterlockPlacementPass())
176-
.RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
135+
.RegisterPass(CreateEliminateDeadFunctionsPass());
136+
optimizer.RegisterPass(CreatePrivateToLocalPass());
137+
// Fix up the storage classes that DXC may have purposely generated
138+
// incorrectly. All functions are inlined, and a lot of dead code has
139+
// been removed.
140+
optimizer.RegisterPass(CreateFixStorageClassPass());
141+
// Propagate the value stored to the loads in very simple cases.
142+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
143+
.RegisterPass(CreateLocalSingleStoreElimPass())
144+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
145+
optimizer
146+
// Split up aggregates so they are easier to deal with.
147+
.RegisterPass(CreateScalarReplacementPass(0));
148+
// Remove loads and stores so everything is in intermediate values.
149+
// Takes care of copy propagation of non-members.
150+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
151+
.RegisterPass(CreateLocalSingleStoreElimPass())
152+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
153+
if (ssa_rewrite_mode != SSARewriteMode::None) {
154+
optimizer.RegisterPass(CreateSSARewritePass(ssa_rewrite_mode));
155+
}
156+
optimizer
157+
// Propagate constants to get as many constant conditions on branches
158+
// as possible.
159+
.RegisterPass(CreateCCPPass());
160+
if (include_loop_unroll) {
161+
optimizer.RegisterPass(CreateLoopUnrollPass(true));
162+
}
163+
optimizer.RegisterPass(CreateDeadBranchElimPass())
164+
// Copy propagate members. Cleans up code sequences generated by scalar
165+
// replacement. Also important for removing OpPhi nodes.
166+
.RegisterPass(CreateSimplificationPass());
167+
return optimizer
168+
// May need loop unrolling here see
169+
// https://github.com/Microsoft/DirectXShaderCompiler/pull/930
170+
// Get rid of unused code that contains traces of illegal code
171+
// or unused references to unbound external objects
172+
.RegisterPass(CreateVectorDCEPass())
173+
.RegisterPass(CreateDeadInsertElimPass())
174+
.RegisterPass(CreateReduceLoadSizePass())
175+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
176+
.RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
177+
.RegisterPass(CreateInterpolateFixupPass())
178+
.RegisterPass(CreateInvocationInterlockPlacementPass())
179+
.RegisterPass(CreateOpExtInstWithForwardReferenceFixupPass());
177180
}
178181

179182
Optimizer& Optimizer::RegisterLegalizationPasses() {
180-
return RegisterLegalizationPasses(false);
183+
return RegisterLegalizationPasses(false, true, SSARewriteMode::All);
184+
}
185+
186+
Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
187+
return RegisterLegalizationPasses(preserve_interface, true,
188+
SSARewriteMode::All);
181189
}
182190

183191
Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
184-
return RegisterPass(CreateWrapOpKillPass())
192+
auto& optimizer = RegisterPass(CreateWrapOpKillPass())
185193
.RegisterPass(CreateDeadBranchElimPass())
186194
.RegisterPass(CreateMergeReturnPass())
187195
.RegisterPass(CreateInlineExhaustivePass())
188196
.RegisterPass(CreateEliminateDeadFunctionsPass())
189-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
190197
.RegisterPass(CreatePrivateToLocalPass())
191198
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
192199
.RegisterPass(CreateLocalSingleStoreElimPass())
193200
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
194201
.RegisterPass(CreateScalarReplacementPass(0))
195-
.RegisterPass(CreateLocalAccessChainConvertPass())
196-
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
202+
.RegisterPass(CreateLocalAccessChainConvertPass());
203+
optimizer.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
197204
.RegisterPass(CreateLocalSingleStoreElimPass())
198-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
199-
.RegisterPass(CreateLocalMultiStoreElimPass())
200-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
201-
.RegisterPass(CreateCCPPass())
202-
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
203-
.RegisterPass(CreateLoopUnrollPass(true))
204-
.RegisterPass(CreateDeadBranchElimPass())
205-
.RegisterPass(CreateRedundancyEliminationPass())
206-
.RegisterPass(CreateCombineAccessChainsPass())
205+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
206+
optimizer.RegisterPass(CreateCCPPass())
207+
.RegisterPass(CreateAggressiveDCEPass(preserve_interface));
208+
// Preserve LoopControl::Unroll in the IR instead of always materializing
209+
// it here. The optimizer-side full unroll is very costly on large modules
210+
// with many tiny [unroll]-annotated loops, while the hint remains available
211+
// to downstream consumers in the final SPIR-V.
212+
optimizer.RegisterPass(CreateDeadBranchElimPass());
213+
optimizer.RegisterPass(CreateLocalRedundancyEliminationPass());
214+
optimizer.RegisterPass(CreateCombineAccessChainsPass())
207215
.RegisterPass(CreateSimplificationPass())
208216
.RegisterPass(CreateScalarReplacementPass(0))
209217
.RegisterPass(CreateLocalAccessChainConvertPass())
210218
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
211219
.RegisterPass(CreateLocalSingleStoreElimPass())
212220
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
213-
.RegisterPass(CreateSSARewritePass())
221+
.RegisterPass(CreateSSARewritePass(SSARewriteMode::SpecialTypes))
214222
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
215223
.RegisterPass(CreateVectorDCEPass())
216224
.RegisterPass(CreateDeadInsertElimPass())
@@ -220,9 +228,9 @@ Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
220228
.RegisterPass(CreateCopyPropagateArraysPass())
221229
.RegisterPass(CreateReduceLoadSizePass())
222230
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
223-
.RegisterPass(CreateBlockMergePass())
224-
.RegisterPass(CreateRedundancyEliminationPass())
225-
.RegisterPass(CreateDeadBranchElimPass())
231+
.RegisterPass(CreateBlockMergePass());
232+
optimizer.RegisterPass(CreateLocalRedundancyEliminationPass());
233+
return optimizer.RegisterPass(CreateDeadBranchElimPass())
226234
.RegisterPass(CreateBlockMergePass())
227235
.RegisterPass(CreateSimplificationPass());
228236
}
@@ -401,8 +409,6 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
401409
RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
402410
} else if (pass_name == "loop-unswitch") {
403411
RegisterPass(CreateLoopUnswitchPass());
404-
} else if (pass_name == "legalize-multidim-array") {
405-
RegisterPass(CreateLegalizeMultidimArrayPass());
406412
} else if (pass_name == "scalar-replacement") {
407413
if (pass_args.size() == 0) {
408414
RegisterPass(CreateScalarReplacementPass(0));
@@ -965,11 +971,6 @@ Optimizer::PassToken CreateLoopUnswitchPass() {
965971
MakeUnique<opt::LoopUnswitchPass>());
966972
}
967973

968-
Optimizer::PassToken CreateLegalizeMultidimArrayPass() {
969-
return MakeUnique<Optimizer::PassToken::Impl>(
970-
MakeUnique<opt::LegalizeMultidimArrayPass>());
971-
}
972-
973974
Optimizer::PassToken CreateRedundancyEliminationPass() {
974975
return MakeUnique<Optimizer::PassToken::Impl>(
975976
MakeUnique<opt::RedundancyEliminationPass>());
@@ -1019,9 +1020,9 @@ Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
10191020
MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
10201021
}
10211022

1022-
Optimizer::PassToken CreateSSARewritePass() {
1023+
Optimizer::PassToken CreateSSARewritePass(SSARewriteMode mode) {
10231024
return MakeUnique<Optimizer::PassToken::Impl>(
1024-
MakeUnique<opt::SSARewritePass>());
1025+
MakeUnique<opt::SSARewritePass>(mode));
10251026
}
10261027

10271028
Optimizer::PassToken CreateCopyPropagateArraysPass() {

source/opt/ssa_rewrite_pass.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -294,7 +294,8 @@ class SSARewriter {
294294

295295
class SSARewritePass : public MemPass {
296296
public:
297-
SSARewritePass() = default;
297+
explicit SSARewritePass(SSARewriteMode mode = SSARewriteMode::All)
298+
: MemPass(mode) {}
298299

299300
const char* name() const override { return "ssa-rewrite"; }
300301
Status Process() override;

0 commit comments

Comments
 (0)