Skip to content

Commit 961e5c2

Browse files
metaflowGoogle-ML-Automation
authored and committed
[XLA:GPU] move TransposeFolding after simplifier pipeline
The combination of DotDecompose, AlgebraicSimplifier, and TransposeFolding might never reach a fixed point and can get stuck in a rewrite loop. Moving TransposeFolding into the simplification-2 pipeline should result in similar output, as we keep the relative order. Also: - Added a few tests to detect / record cases when we run into a rewrite loop - Log a warning when HloPassFix reaches the iteration limit, as this is likely a similar bug. PiperOrigin-RevId: 715279302
1 parent d4646a9 commit 961e5c2

File tree

3 files changed

+63
-6
lines changed

3 files changed

+63
-6
lines changed

xla/hlo/pass/hlo_pass_fix.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919
#include <algorithm>
2020
#include <type_traits>
2121

22+
#include "absl/log/log.h"
2223
#include "absl/status/statusor.h"
2324
#include "xla/hlo/ir/hlo_module.h"
2425
#include "xla/hlo/ir/hlo_module_group.h"
@@ -76,8 +77,8 @@ class HloPassFix : public Pass {
7677
VLOG(3) << "changed_this_iteration: " << changed_this_iteration;
7778
++iteration_count;
7879
if (iteration_count == kIterationLimit) {
79-
VLOG(1) << "Unexpectedly high number of iterations in HLO passes, "
80-
"exiting fixed point loop.";
80+
LOG(WARNING) << "Unexpectedly high number of iterations in HLO passes, "
81+
"exiting fixed point loop.";
8182
// Return false in case this fixed point is nested.
8283
return false;
8384
}
@@ -98,9 +99,9 @@ class HloPassFix : public Pass {
9899
<< !run_state->changed_last_iteration.empty();
99100
run_state->IncrementIteration();
100101
if (run_state->iteration == kIterationLimit) {
101-
VLOG(1) << "Unexpectedly high number of iterations in HLO passes '"
102-
<< Pass::name() << "' for module '" << module->name()
103-
<< "'. Exiting fixed point loop.";
102+
LOG(WARNING) << "Unexpectedly high number of iterations in HLO passes '"
103+
<< Pass::name() << "' for module '" << module->name()
104+
<< "'. Exiting fixed point loop.";
104105
// Clear changed and abort in case this fixed point is nested.
105106
run_state->changed.clear();
106107
break;

xla/service/gpu/gpu_compiler.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -821,7 +821,6 @@ absl::Status RunOptimizationPasses(
821821
pipeline.AddPass<HloConstantFolding>();
822822
pipeline.AddPass<ConditionalSimplifier>();
823823
pipeline.AddPass<RealImagExpander>();
824-
pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot);
825824
pipeline.AddPass<HloCSE>(/*is_layout_sensitive=*/false);
826825
pipeline.AddPass<HloDCE>();
827826
}();
@@ -835,6 +834,7 @@ absl::Status RunOptimizationPasses(
835834
pipeline.AddPass<ConvertMover>();
836835
pipeline.AddPass<GpuAlgebraicSimplifier>(layout_insensitive_algsimp_opts,
837836
gpu_version);
837+
pipeline.AddPass<TransposeFolding>(CanFoldTransposeOperandIntoDot);
838838
}();
839839

840840
pipeline.AddPass<HloComputationDeduplicator>(

xla/service/gpu/gpu_compiler_test.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1568,6 +1568,62 @@ TEST_F(PassOrderTest,
15681568
"remove-no-op-reduce-precision-algebraic-simplifier");
15691569
}
15701570

1571+
// Tests that passes are converging and pipelines reach a fixed point.
1572+
class FixPointTest : public HloTestBase {
1573+
public:
1574+
void ExpectPipelinesReachFixedPoint(absl::string_view module_text) {
1575+
std::unique_ptr<HloModule> optimized_module;
1576+
TF_ASSERT_OK_AND_ASSIGN(
1577+
std::unique_ptr<VerifiedHloModule> module,
1578+
ParseAndReturnVerifiedModule(module_text, GetModuleConfigForTest()));
1579+
TF_ASSERT_OK_AND_ASSIGN(optimized_module,
1580+
GetOptimizedModule(std::move(module)));
1581+
1582+
std::string last_pipeline_name;
1583+
int count = 0;
1584+
for (const HloPassMetadata& pass_metadata :
1585+
optimized_module->metadata()->proto().pass_metadata()) {
1586+
if (pass_metadata.pass_name() != "pipeline-start") {
1587+
continue;
1588+
}
1589+
VLOG(2) << "pipeline: " << pass_metadata.pipeline_name();
1590+
if (pass_metadata.pipeline_name() != last_pipeline_name) {
1591+
count = 0;
1592+
last_pipeline_name = pass_metadata.pipeline_name();
1593+
}
1594+
count++;
1595+
// 25 is the default iteration limit of HloPassFix.
1596+
EXPECT_LT(count, 25) << "Pipeline '" << pass_metadata.pipeline_name()
1597+
<< "' ran " << count
1598+
<< " times. That is likely an indication that the "
1599+
"pipeline is not reaching a fixed point.";
1600+
}
1601+
}
1602+
};
1603+
1604+
TEST_F(FixPointTest, Constant) {
1605+
ExpectPipelinesReachFixedPoint(R"(ENTRY main {
1606+
ROOT constant = f32[] constant(0)
1607+
})");
1608+
}
1609+
1610+
TEST_F(FixPointTest, ReshapeTranspose) {
1611+
ExpectPipelinesReachFixedPoint(R"(ENTRY main {
1612+
p0 = f32[1024,4096]{1,0} parameter(0)
1613+
reshape = f32[1024,1024,4]{2,1,0} reshape(p0)
1614+
ROOT transpose = f32[4,1024,1024]{2,1,0} transpose(reshape), dimensions={2,1,0}
1615+
})");
1616+
}
1617+
1618+
TEST_F(FixPointTest, DotWithBatchDims) {
1619+
// Reduced test case for b/383729716.
1620+
ExpectPipelinesReachFixedPoint(R"(ENTRY main {
1621+
p0 = f32[8,4,64]{2,1,0} parameter(0)
1622+
p1 = f32[4,64,1024] parameter(1)
1623+
ROOT dot = f32[4,8,1024]{2,1,0} dot(p0, p1), lhs_batch_dims={1}, lhs_contracting_dims={2}, rhs_batch_dims={0}, rhs_contracting_dims={1}
1624+
})");
1625+
}
1626+
15711627
} // namespace
15721628
} // namespace gpu
15731629
} // namespace xla

0 commit comments

Comments
 (0)