Skip to content

Commit 7e3efaf

Browse files
Support scalar replacement of large structs (#6019)
* Removes size limit on ScalarReplacementPass
* Adds test to verify default ScalarReplacementPass size limit

---------

Co-authored-by: Zackery Mason-Blaug <zackery.mason-blaug@ntd.nintendo.com>
1 parent d3fc6ed commit 7e3efaf

File tree

6 files changed

+127
-9
lines changed

6 files changed

+127
-9
lines changed

include/spirv-tools/optimizer.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -655,7 +655,7 @@ Optimizer::PassToken CreateRedundancyEliminationPass();
655655
// element if those elements are accessed individually. The parameter is a
656656
// limit on the number of members in the composite variable that the pass will
657657
// consider replacing. A limit of 0 means there is no limit.
658-
Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit = 100);
658+
Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit = 0);
659659

660660
// Create a private to local pass.
661661
// This pass looks for variables declared in the private storage class that are

source/opt/optimizer.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
189189
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
190190
.RegisterPass(CreateLocalSingleStoreElimPass())
191191
.RegisterPass(CreateAggressiveDCEPass(preserve_interface))
192-
.RegisterPass(CreateScalarReplacementPass())
192+
.RegisterPass(CreateScalarReplacementPass(0))
193193
.RegisterPass(CreateLocalAccessChainConvertPass())
194194
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
195195
.RegisterPass(CreateLocalSingleStoreElimPass())
@@ -203,7 +203,7 @@ Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
203203
.RegisterPass(CreateRedundancyEliminationPass())
204204
.RegisterPass(CreateCombineAccessChainsPass())
205205
.RegisterPass(CreateSimplificationPass())
206-
.RegisterPass(CreateScalarReplacementPass())
206+
.RegisterPass(CreateScalarReplacementPass(0))
207207
.RegisterPass(CreateLocalAccessChainConvertPass())
208208
.RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
209209
.RegisterPass(CreateLocalSingleStoreElimPass())
@@ -401,7 +401,7 @@ bool Optimizer::RegisterPassFromFlag(const std::string& flag,
401401
RegisterPass(CreateLoopUnswitchPass());
402402
} else if (pass_name == "scalar-replacement") {
403403
if (pass_args.size() == 0) {
404-
RegisterPass(CreateScalarReplacementPass());
404+
RegisterPass(CreateScalarReplacementPass(0));
405405
} else {
406406
int limit = -1;
407407
if (pass_args.find_first_not_of("0123456789") == std::string::npos) {

source/opt/scalar_replacement_pass.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ namespace opt {
3333
// Documented in optimizer.hpp
3434
class ScalarReplacementPass : public MemPass {
3535
private:
36-
static constexpr uint32_t kDefaultLimit = 100;
36+
static constexpr uint32_t kDefaultLimit = 0;
3737

3838
public:
3939
ScalarReplacementPass(uint32_t limit = kDefaultLimit)

test/opt/optimizer_test.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -388,6 +388,124 @@ OpFunctionEnd
388388
<< "Was expecting the result id of DebugScope to have been changed.";
389389
}
390390

391+
TEST(Optimizer, CheckDefaultPerformancePassesLargeStructScalarization) {
392+
std::string start = R"(OpCapability Shader
393+
%1 = OpExtInstImport "GLSL.std.450"
394+
OpMemoryModel Logical GLSL450
395+
OpEntryPoint Vertex %4 "main" %46 %48
396+
OpSource GLSL 430
397+
OpName %4 "main"
398+
OpDecorate %44 Block
399+
OpMemberDecorate %44 0 BuiltIn Position
400+
OpMemberDecorate %44 1 BuiltIn PointSize
401+
OpMemberDecorate %44 2 BuiltIn ClipDistance
402+
OpDecorate %48 Location 0
403+
%2 = OpTypeVoid
404+
%3 = OpTypeFunction %2
405+
%6 = OpTypeFloat 32
406+
%7 = OpTypeVector %6 4
407+
%8 = OpTypePointer Function %7
408+
%9 = OpTypeStruct %7)";
409+
410+
// add 200 float members to the struct
411+
for (int i = 0; i < 200; i++) {
412+
start += " %6";
413+
}
414+
415+
start += R"(
416+
%10 = OpTypeFunction %9 %8
417+
%14 = OpTypeFunction %6 %9
418+
%18 = OpTypePointer Function %9
419+
%20 = OpTypeInt 32 1
420+
%21 = OpConstant %20 0
421+
%24 = OpConstant %20 1
422+
%25 = OpTypeInt 32 0
423+
%26 = OpConstant %25 1
424+
%27 = OpTypePointer Function %6
425+
%43 = OpTypeArray %6 %26
426+
%44 = OpTypeStruct %7 %6 %43
427+
%45 = OpTypePointer Output %44
428+
%46 = OpVariable %45 Output
429+
%47 = OpTypePointer Input %7
430+
%48 = OpVariable %47 Input
431+
%54 = OpTypePointer Output %7
432+
%4 = OpFunction %2 None %3
433+
%5 = OpLabel
434+
%49 = OpVariable %8 Function
435+
%50 = OpLoad %7 %48
436+
OpStore %49 %50
437+
%51 = OpFunctionCall %9 %12 %49
438+
%52 = OpFunctionCall %6 %16 %51
439+
%53 = OpCompositeConstruct %7 %52 %52 %52 %52
440+
%55 = OpAccessChain %54 %46 %21
441+
OpStore %55 %53
442+
OpReturn
443+
OpFunctionEnd
444+
%12 = OpFunction %9 None %10
445+
%11 = OpFunctionParameter %8
446+
%13 = OpLabel
447+
%19 = OpVariable %18 Function
448+
%22 = OpLoad %7 %11
449+
%23 = OpAccessChain %8 %19 %21
450+
OpStore %23 %22
451+
%28 = OpAccessChain %27 %11 %26
452+
%29 = OpLoad %6 %28
453+
%30 = OpConvertFToS %20 %29
454+
%31 = OpAccessChain %27 %19 %21 %30
455+
%32 = OpLoad %6 %31
456+
%33 = OpAccessChain %27 %19 %24
457+
OpStore %33 %32
458+
%34 = OpLoad %9 %19
459+
OpReturnValue %34
460+
OpFunctionEnd
461+
%16 = OpFunction %6 None %14
462+
%15 = OpFunctionParameter %9
463+
%17 = OpLabel
464+
%37 = OpCompositeExtract %6 %15 1
465+
%38 = OpConvertFToS %20 %37
466+
%39 = OpCompositeExtract %7 %15 0
467+
%40 = OpVectorExtractDynamic %6 %39 %38
468+
OpReturnValue %40
469+
OpFunctionEnd)";
470+
471+
std::vector<uint32_t> binary;
472+
SpirvTools tools(SPV_ENV_VULKAN_1_3);
473+
tools.Assemble(start, &binary,
474+
SPV_TEXT_TO_BINARY_OPTION_PRESERVE_NUMERIC_IDS);
475+
476+
std::string test_disassembly;
477+
std::string default_disassembly;
478+
479+
{
480+
Optimizer opt(SPV_ENV_VULKAN_1_3);
481+
opt.RegisterPerformancePasses();
482+
483+
std::vector<uint32_t> optimized;
484+
ASSERT_TRUE(opt.Run(binary.data(), binary.size(), &optimized))
485+
<< start << "\n";
486+
487+
tools.Disassemble(optimized.data(), optimized.size(), &default_disassembly,
488+
SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
489+
}
490+
491+
{
492+
// default passes should not benefit from additional scalar replacement
493+
Optimizer opt(SPV_ENV_VULKAN_1_3);
494+
opt.RegisterPerformancePasses()
495+
.RegisterPass(CreateScalarReplacementPass(201))
496+
.RegisterPass(CreateAggressiveDCEPass());
497+
498+
std::vector<uint32_t> optimized;
499+
ASSERT_TRUE(opt.Run(binary.data(), binary.size(), &optimized))
500+
<< start << "\n";
501+
502+
tools.Disassemble(optimized.data(), optimized.size(), &test_disassembly,
503+
SPV_BINARY_TO_TEXT_OPTION_NO_HEADER);
504+
}
505+
506+
EXPECT_EQ(test_disassembly, default_disassembly);
507+
}
508+
391509
} // namespace
392510
} // namespace opt
393511
} // namespace spvtools

test/opt/scalar_replacement_test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ using ScalarReplacementPassName = ::testing::Test;
2727

2828
TEST_F(ScalarReplacementPassName, Default) {
2929
auto srp = ScalarReplacementPass();
30-
EXPECT_STREQ(srp.name(), "scalar-replacement=100");
30+
EXPECT_STREQ(srp.name(), "scalar-replacement=0");
3131
}
3232

3333
TEST_F(ScalarReplacementPassName, Large) {

test/tools/opt/flags.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ class TestValidPassFlags(expect.ValidObjectFile1_6,
113113
'remove-duplicates',
114114
'replace-invalid-opcode',
115115
'ssa-rewrite',
116-
'scalar-replacement=100',
116+
'scalar-replacement=0',
117117
'scalar-replacement=42',
118118
'strength-reduction',
119119
'strip-debug',
@@ -148,7 +148,7 @@ class TestPerformanceOptimizationPasses(expect.ValidObjectFile1_6,
148148
'eliminate-local-single-block',
149149
'eliminate-local-single-store',
150150
'eliminate-dead-code-aggressive',
151-
'scalar-replacement=100',
151+
'scalar-replacement=0',
152152
'convert-local-access-chains',
153153
'eliminate-local-single-block',
154154
'eliminate-local-single-store',
@@ -162,7 +162,7 @@ class TestPerformanceOptimizationPasses(expect.ValidObjectFile1_6,
162162
'redundancy-elimination',
163163
'combine-access-chains',
164164
'simplify-instructions',
165-
'scalar-replacement=100',
165+
'scalar-replacement=0',
166166
'convert-local-access-chains',
167167
'eliminate-local-single-block',
168168
'eliminate-local-single-store',

0 commit comments

Comments
 (0)