Skip to content

Commit 2eea949

Browse files
TixxxGoogle-ML-Automation
authored andcommitted
PR #21104: [NVIDIA GPU] Preserve backend config when folding transpose
Imported from GitHub PR #21104 Transpose folding pass doesn't preserve backend config when creating the new dot with transpose folded. Changing the behavior to copy the old dot's config to the new dot. Copybara import of the project: -- d2d6b62 by TJ Xu <[email protected]>: Preserve backend config when folding transpose -- 6b5fa3a by TJ Xu <[email protected]>: use SetupDerivedInstruction instead of just copying the backend config Merging this change closes #21104 COPYBARA_INTEGRATE_REVIEW=#21104 from Tixxx:tixxx/transpose_folding 6b5fa3a PiperOrigin-RevId: 715204523
1 parent 2e7cb97 commit 2eea949

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

xla/service/transpose_folding.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,10 +101,11 @@ absl::Status FoldTransposeIntoDot(InstructionOperandsPair& pair) {
101101
rhs = rhs->mutable_operand(0);
102102
}
103103
}
104-
105-
return dot->parent()->ReplaceWithNewInstruction(
106-
dot, HloInstruction::CreateDot(dot->shape(), lhs, rhs, new_dot_dims,
107-
dot->precision_config()));
104+
HloInstruction* new_dot =
105+
dot->parent()->AddInstruction(HloInstruction::CreateDot(
106+
dot->shape(), lhs, rhs, new_dot_dims, dot->precision_config()));
107+
dot->SetupDerivedInstruction(new_dot);
108+
return dot->parent()->ReplaceInstruction(dot, new_dot);
108109
}
109110

110111
// Folds the operands of `convolution` that are foldable transposes.

xla/service/transpose_folding_test.cc

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -559,5 +559,25 @@ ENTRY entry_computation {
559559
EXPECT_THAT(TransposeFolding().Run(module.get()), IsOkAndHolds(false));
560560
}
561561

562+
TEST_F(TransposeFoldingTest, FoldTransposeWithBackendConfig) {
563+
constexpr absl::string_view kHloString = R"(
564+
HloModule FoldTransposeWithBackendConfig
565+
566+
ENTRY entry_computation {
567+
x = f32[7,2,7,3]{3,2,1,0} parameter(0)
568+
y = f32[7,2,7,3]{3,2,1,0} parameter(1)
569+
transpose = f32[7,3,7,2]{3,2,1,0} transpose(y), dimensions={0,3,2,1}
570+
ROOT dot = f32[7,7,2,2]{3,2,1,0} dot(x, transpose), lhs_contracting_dims={3},
571+
rhs_contracting_dims={1}, lhs_batch_dims={0,2}, rhs_batch_dims={0,2}, backend_config={"force_earliest_schedule":false}
572+
}
573+
)";
574+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
575+
ParseAndReturnVerifiedModule(kHloString));
576+
577+
EXPECT_THAT(TransposeFolding().Run(module.get()), IsOkAndHolds(true));
578+
EXPECT_TRUE(
579+
module->entry_computation()->root_instruction()->has_backend_config());
580+
}
581+
562582
} // namespace
563583
} // namespace xla

0 commit comments

Comments
 (0)