Skip to content

Commit de3d5ac

Browse files
alan-bakerdneto0
andauthored
Add tooling support for SPV_KHR_maximal_reconvergence (KhronosGroup#5542)
* Validation for SPV_KHR_maximal_reconvergence * Add pass to add/remove maximal reconvergence execution mode --------- Co-authored-by: David Neto <[email protected]>
1 parent 14000ad commit de3d5ac

16 files changed

+954
-2
lines changed

Android.mk

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,7 @@ SPVTOOLS_OPT_SRC_FILES := \
157157
source/opt/loop_utils.cpp \
158158
source/opt/mem_pass.cpp \
159159
source/opt/merge_return_pass.cpp \
160+
source/opt/modify_maximal_reconvergence.cpp \
160161
source/opt/module.cpp \
161162
source/opt/optimizer.cpp \
162163
source/opt/pass.cpp \

BUILD.gn

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,8 @@ static_library("spvtools_opt") {
737737
"source/opt/mem_pass.h",
738738
"source/opt/merge_return_pass.cpp",
739739
"source/opt/merge_return_pass.h",
740+
"source/opt/modify_maximal_reconvergence.cpp",
741+
"source/opt/modify_maximal_reconvergence.h",
740742
"source/opt/module.cpp",
741743
"source/opt/module.h",
742744
"source/opt/null_pass.h",

DEPS

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ vars = {
1313
'protobuf_revision': 'v21.12',
1414

1515
're2_revision': '264e71e88e1c8a4b5ec326e70e9cf1d476f58a58',
16-
'spirv_headers_revision': '7b0309708da5126b89e4ce6f19835f36dc912f2f',
16+
'spirv_headers_revision': 'ae6a8b39717523d96683bc0d20b541944e28072f',
1717
}
1818

1919
deps = {

include/spirv-tools/optimizer.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1003,6 +1003,10 @@ Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t ds_from,
10031003
// OpBeginInterlockInvocationEXT and one OpEndInterlockInvocationEXT, in that
10041004
// order.
10051005
Optimizer::PassToken CreateInvocationInterlockPlacementPass();
1006+
1007+
// Creates a pass to add/remove maximal reconvergence execution mode.
1008+
// This pass either adds or removes maximal reconvergence from all entry points.
1009+
Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add);
10061010
} // namespace spvtools
10071011

10081012
#endif // INCLUDE_SPIRV_TOOLS_OPTIMIZER_HPP_

source/opt/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
9494
loop_unswitch_pass.h
9595
mem_pass.h
9696
merge_return_pass.h
97+
modify_maximal_reconvergence.h
9798
module.h
9899
null_pass.h
99100
passes.h
@@ -214,6 +215,7 @@ set(SPIRV_TOOLS_OPT_SOURCES
214215
loop_unswitch_pass.cpp
215216
mem_pass.cpp
216217
merge_return_pass.cpp
218+
modify_maximal_reconvergence.cpp
217219
module.cpp
218220
optimizer.cpp
219221
pass.cpp
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
// Copyright (c) 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "modify_maximal_reconvergence.h"
16+
17+
#include "source/opt/ir_context.h"
18+
#include "source/util/make_unique.h"
19+
20+
namespace spvtools {
21+
namespace opt {
22+
23+
Pass::Status ModifyMaximalReconvergence::Process() {
24+
bool changed = false;
25+
if (add_) {
26+
changed = AddMaximalReconvergence();
27+
} else {
28+
changed = RemoveMaximalReconvergence();
29+
}
30+
return changed ? Pass::Status::SuccessWithChange
31+
: Pass::Status::SuccessWithoutChange;
32+
}
33+
34+
bool ModifyMaximalReconvergence::AddMaximalReconvergence() {
35+
bool changed = false;
36+
bool has_extension = false;
37+
bool has_shader =
38+
context()->get_feature_mgr()->HasCapability(spv::Capability::Shader);
39+
for (auto extension : context()->extensions()) {
40+
if (extension.GetOperand(0).AsString() == "SPV_KHR_maximal_reconvergence") {
41+
has_extension = true;
42+
break;
43+
}
44+
}
45+
46+
std::unordered_set<uint32_t> entry_points_with_mode;
47+
for (auto mode : get_module()->execution_modes()) {
48+
if (spv::ExecutionMode(mode.GetSingleWordInOperand(1)) ==
49+
spv::ExecutionMode::MaximallyReconvergesKHR) {
50+
entry_points_with_mode.insert(mode.GetSingleWordInOperand(0));
51+
}
52+
}
53+
54+
for (auto entry_point : get_module()->entry_points()) {
55+
const uint32_t id = entry_point.GetSingleWordInOperand(1);
56+
if (!entry_points_with_mode.count(id)) {
57+
changed = true;
58+
if (!has_extension) {
59+
context()->AddExtension("SPV_KHR_maximal_reconvergence");
60+
has_extension = true;
61+
}
62+
if (!has_shader) {
63+
context()->AddCapability(spv::Capability::Shader);
64+
has_shader = true;
65+
}
66+
context()->AddExecutionMode(MakeUnique<Instruction>(
67+
context(), spv::Op::OpExecutionMode, 0, 0,
68+
std::initializer_list<Operand>{
69+
{SPV_OPERAND_TYPE_ID, {id}},
70+
{SPV_OPERAND_TYPE_EXECUTION_MODE,
71+
{static_cast<uint32_t>(
72+
spv::ExecutionMode::MaximallyReconvergesKHR)}}}));
73+
entry_points_with_mode.insert(id);
74+
}
75+
}
76+
77+
return changed;
78+
}
79+
80+
bool ModifyMaximalReconvergence::RemoveMaximalReconvergence() {
81+
bool changed = false;
82+
std::vector<Instruction*> to_remove;
83+
Instruction* mode = &*get_module()->execution_mode_begin();
84+
while (mode) {
85+
if (mode->opcode() != spv::Op::OpExecutionMode &&
86+
mode->opcode() != spv::Op::OpExecutionModeId) {
87+
break;
88+
}
89+
if (spv::ExecutionMode(mode->GetSingleWordInOperand(1)) ==
90+
spv::ExecutionMode::MaximallyReconvergesKHR) {
91+
mode = context()->KillInst(mode);
92+
changed = true;
93+
} else {
94+
mode = mode->NextNode();
95+
}
96+
}
97+
98+
changed |=
99+
context()->RemoveExtension(Extension::kSPV_KHR_maximal_reconvergence);
100+
return changed;
101+
}
102+
} // namespace opt
103+
} // namespace spvtools
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) 2024 Google LLC
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#ifndef LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_
16+
#define LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_
17+
18+
#include "pass.h"
19+
20+
namespace spvtools {
21+
namespace opt {
22+
23+
// Modifies entry points to either add or remove MaximallyReconvergesKHR
24+
//
25+
// This pass will either add or remove MaximallyReconvergesKHR to all entry
26+
// points in the module. When adding the execution mode, it does not attempt to
27+
// determine whether any ray tracing invocation repack instructions might be
28+
// executed because it is a runtime restriction. That is left to the user.
29+
class ModifyMaximalReconvergence : public Pass {
30+
public:
31+
const char* name() const override { return "modify-maximal-reconvergence"; }
32+
Status Process() override;
33+
34+
explicit ModifyMaximalReconvergence(bool add = true) : Pass(), add_(add) {}
35+
36+
IRContext::Analysis GetPreservedAnalyses() override {
37+
return IRContext::kAnalysisDefUse |
38+
IRContext::kAnalysisInstrToBlockMapping |
39+
IRContext::kAnalysisDecorations | IRContext::kAnalysisCombinators |
40+
IRContext::kAnalysisCFG | IRContext::kAnalysisNameMap |
41+
IRContext::kAnalysisConstants | IRContext::kAnalysisTypes;
42+
}
43+
44+
private:
45+
bool AddMaximalReconvergence();
46+
bool RemoveMaximalReconvergence();
47+
48+
bool add_;
49+
};
50+
} // namespace opt
51+
} // namespace spvtools
52+
53+
#endif // LIBSPIRV_OPT_MODIFY_MAXIMAL_RECONVERGENCE_H_

source/opt/optimizer.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,23 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
606606
return false;
607607
}
608608
RegisterPass(CreateSwitchDescriptorSetPass(from_set, to_set));
609+
} else if (pass_name == "modify-maximal-reconvergence") {
610+
if (pass_args.size() == 0) {
611+
Error(consumer(), nullptr, {},
612+
"--modify-maximal-reconvergence requires an argument");
613+
return false;
614+
}
615+
if (pass_args == "add") {
616+
RegisterPass(CreateModifyMaximalReconvergencePass(true));
617+
} else if (pass_args == "remove") {
618+
RegisterPass(CreateModifyMaximalReconvergencePass(false));
619+
} else {
620+
Errorf(consumer(), nullptr, {},
621+
"Invalid argument for --modify-maximal-reconvergence: %s (must be "
622+
"'add' or 'remove')",
623+
pass_args.c_str());
624+
return false;
625+
}
609626
} else {
610627
Errorf(consumer(), nullptr, {},
611628
"Unknown flag '--%s'. Use --help for a list of valid flags",
@@ -1141,6 +1158,11 @@ Optimizer::PassToken CreateInvocationInterlockPlacementPass() {
11411158
return MakeUnique<Optimizer::PassToken::Impl>(
11421159
MakeUnique<opt::InvocationInterlockPlacementPass>());
11431160
}
1161+
1162+
Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add) {
1163+
return MakeUnique<Optimizer::PassToken::Impl>(
1164+
MakeUnique<opt::ModifyMaximalReconvergence>(add));
1165+
}
11441166
} // namespace spvtools
11451167

11461168
extern "C" {

source/opt/passes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
#include "source/opt/loop_unroller.h"
6666
#include "source/opt/loop_unswitch_pass.h"
6767
#include "source/opt/merge_return_pass.h"
68+
#include "source/opt/modify_maximal_reconvergence.h"
6869
#include "source/opt/null_pass.h"
6970
#include "source/opt/private_to_local_pass.h"
7071
#include "source/opt/reduce_load_size.h"

source/val/validate_cfg.cpp

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ spv_result_t ValidateBranchConditional(ValidationState_t& _,
190190
"ID of an OpLabel instruction";
191191
}
192192

193+
// A similar requirement for SPV_KHR_maximal_reconvergence is deferred until
194+
// entry point call trees have been reconrded.
193195
if (_.version() >= SPV_SPIRV_VERSION_WORD(1, 6) && true_id == false_id) {
194196
return _.diag(SPV_ERROR_INVALID_ID, inst)
195197
<< "In SPIR-V 1.6 or later, True Label and False Label must be "
@@ -875,6 +877,95 @@ spv_result_t StructuredControlFlowChecks(
875877
return SPV_SUCCESS;
876878
}
877879

880+
spv_result_t MaximalReconvergenceChecks(ValidationState_t& _) {
881+
// Find all the entry points with the MaximallyReconvergencesKHR execution
882+
// mode.
883+
std::unordered_set<uint32_t> maximal_funcs;
884+
std::unordered_set<uint32_t> maximal_entry_points;
885+
for (auto entry_point : _.entry_points()) {
886+
const auto* exec_modes = _.GetExecutionModes(entry_point);
887+
if (exec_modes &&
888+
exec_modes->count(spv::ExecutionMode::MaximallyReconvergesKHR)) {
889+
maximal_entry_points.insert(entry_point);
890+
maximal_funcs.insert(entry_point);
891+
}
892+
}
893+
894+
if (maximal_entry_points.empty()) {
895+
return SPV_SUCCESS;
896+
}
897+
898+
// Find all the functions reachable from a maximal reconvergence entry point.
899+
for (const auto& func : _.functions()) {
900+
const auto& entry_points = _.EntryPointReferences(func.id());
901+
for (auto id : entry_points) {
902+
if (maximal_entry_points.count(id)) {
903+
maximal_funcs.insert(func.id());
904+
break;
905+
}
906+
}
907+
}
908+
909+
// Check for conditional branches with the same true and false targets.
910+
for (const auto& inst : _.ordered_instructions()) {
911+
if (inst.opcode() == spv::Op::OpBranchConditional) {
912+
const auto true_id = inst.GetOperandAs<uint32_t>(1);
913+
const auto false_id = inst.GetOperandAs<uint32_t>(2);
914+
if (true_id == false_id && maximal_funcs.count(inst.function()->id())) {
915+
return _.diag(SPV_ERROR_INVALID_ID, &inst)
916+
<< "In entry points using the MaximallyReconvergesKHR execution "
917+
"mode, True Label and False Label must be different labels";
918+
}
919+
}
920+
}
921+
922+
// Check for invalid multiple predecessors. Only loop headers, continue
923+
// targets, merge targets or switch targets or defaults may have multiple
924+
// unique predecessors.
925+
for (const auto& func : _.functions()) {
926+
if (!maximal_funcs.count(func.id())) continue;
927+
928+
for (const auto* block : func.ordered_blocks()) {
929+
std::unordered_set<uint32_t> unique_preds;
930+
const auto* preds = block->predecessors();
931+
if (!preds) continue;
932+
933+
for (const auto* pred : *preds) {
934+
unique_preds.insert(pred->id());
935+
}
936+
if (unique_preds.size() < 2) continue;
937+
938+
const auto* terminator = block->terminator();
939+
const auto index = terminator - &_.ordered_instructions()[0];
940+
const auto* pre_terminator = &_.ordered_instructions()[index - 1];
941+
if (pre_terminator->opcode() == spv::Op::OpLoopMerge) continue;
942+
943+
const auto* label = _.FindDef(block->id());
944+
bool ok = false;
945+
for (const auto& pair : label->uses()) {
946+
const auto* use_inst = pair.first;
947+
switch (use_inst->opcode()) {
948+
case spv::Op::OpSelectionMerge:
949+
case spv::Op::OpLoopMerge:
950+
case spv::Op::OpSwitch:
951+
ok = true;
952+
break;
953+
default:
954+
break;
955+
}
956+
}
957+
if (!ok) {
958+
return _.diag(SPV_ERROR_INVALID_CFG, label)
959+
<< "In entry points using the MaximallyReconvergesKHR "
960+
"execution mode, this basic block must not have multiple "
961+
"unique predecessors";
962+
}
963+
}
964+
}
965+
966+
return SPV_SUCCESS;
967+
}
968+
878969
spv_result_t PerformCfgChecks(ValidationState_t& _) {
879970
for (auto& function : _.functions()) {
880971
// Check all referenced blocks are defined within a function
@@ -999,6 +1090,11 @@ spv_result_t PerformCfgChecks(ValidationState_t& _) {
9991090
return error;
10001091
}
10011092
}
1093+
1094+
if (auto error = MaximalReconvergenceChecks(_)) {
1095+
return error;
1096+
}
1097+
10021098
return SPV_SUCCESS;
10031099
}
10041100

0 commit comments

Comments
 (0)