Skip to content

Commit 8eb6944

Browse files
yliu120 authored and
Google-ML-Automation committed
PR #23964: Cleans up call inliner in the XLA shared code path
Imported from GitHub PR #23964. 1. Removes all the GPU-specific logic inside the call inliner. 2. Rewrites `IsInlineableCallOp` to make the code more readable and less error-prone; the previous implementation had implicit priorities among the checks, which could lead to bugs. 3. Removes the test for stream annotation. As discussed with @yashk2810, there are now two controls over the `compute_on` boxes from JAX. The XLA flag `xla_gpu_experimental_stream_annotation` is confusing and should be removed: the control is explicitly placed in JAX, and users would be confused if a second flag also controlled this. Copybara import of the project: -- 1ee6a74 by Yunlong Liu <[email protected]>: cleanup Merging this change closes #23964 COPYBARA_INTEGRATE_REVIEW=#23964 from yliu120:cleanup_call_inliner 1ee6a74 PiperOrigin-RevId: 739113714
1 parent 5e9ef96 commit 8eb6944

File tree

2 files changed

+13
-72
lines changed

2 files changed

+13
-72
lines changed

xla/service/call_inliner.cc

Lines changed: 13 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -185,19 +185,6 @@ bool InlineInstruction(HloInstruction* instruction) {
185185
return true;
186186
}
187187

188-
bool InlineStreamAnnotation(HloInstruction* instruction) {
189-
if (instruction->GetModule()
190-
->config()
191-
.debug_options()
192-
.xla_gpu_experimental_stream_annotation()) {
193-
if (instruction->frontend_attributes().map().contains(
194-
kXlaStreamAnnotationAttr)) {
195-
return false;
196-
}
197-
}
198-
return true;
199-
}
200-
201188
} // namespace
202189

203190
/* static */ absl::StatusOr<CallInliner::InlinedInstructionMap>
@@ -247,12 +234,19 @@ CallInliner::Inline(HloInstruction* call) {
247234
}
248235

249236
bool CallInliner::IsInlineableCallOp(HloInstruction* instruction) const {
250-
return instruction->opcode() == HloOpcode::kCall &&
251-
!instruction->has_backend_config() &&
252-
!instruction->parent()->IsAsyncComputation() &&
253-
InlineInstruction(instruction) && InlineUnderShardy(instruction) &&
254-
InlineComposites(instruction, composites_to_preserve_) &&
255-
InlineStreamAnnotation(instruction);
237+
bool prerequisite = instruction->opcode() == HloOpcode::kCall &&
238+
!instruction->has_backend_config() &&
239+
!instruction->parent()->IsAsyncComputation();
240+
if (!prerequisite) {
241+
return false;
242+
}
243+
if (!InlineInstruction(instruction)) {
244+
// Always prioritize user's explicit requests after fulfilling the
245+
// prerequisites.
246+
return false;
247+
}
248+
return InlineUnderShardy(instruction) &&
249+
InlineComposites(instruction, composites_to_preserve_);
256250
}
257251

258252
absl::StatusOr<bool> CallInliner::Run(

xla/service/call_inliner_test.cc

Lines changed: 0 additions & 53 deletions
Original file line number | Diff line number | Diff line change
@@ -522,59 +522,6 @@ TEST_F(CallInlinerTest, UseShardManualComputationBodySurroundedNotInlined) {
522522
"my_model.___call__.fwd.xla.sdy.manual_computation_body_14.1234");
523523
}
524524

525-
TEST_F(CallInlinerTest, DontInlineStreamAnnotationCall) {
526-
const absl::string_view hlo_string = R"(
527-
HloModule composite
528-
529-
%add (lhs: f32[]) -> f32[] {
530-
%lhs = f32[] parameter(0)
531-
%rhs = f32[] constant(2)
532-
ROOT %add = f32[] add(f32[] %lhs, f32[] %rhs)
533-
}
534-
535-
%sub (lhs: f32[]) -> f32[] {
536-
%lhs = f32[] parameter(0)
537-
%rhs = f32[] constant(1)
538-
ROOT %sub = f32[] subtract(f32[] %lhs, f32[] %rhs)
539-
}
540-
541-
ENTRY %main () -> f32[] {
542-
%lhs = f32[] constant(42)
543-
%call1 = f32[] call(f32[] %lhs), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
544-
ROOT %call2 = f32[] call(f32[] %call1), to_apply=%add
545-
})";
546-
547-
auto debug_options = HloTestBase::GetDebugOptionsForTest();
548-
debug_options.set_xla_gpu_experimental_stream_annotation(true);
549-
auto module = ParseAndReturnVerifiedModule(hlo_string).value();
550-
module->mutable_config().set_debug_options(debug_options);
551-
CallInliner call_inliner(/*single_call_site=*/true);
552-
553-
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
554-
absl::StatusOr<bool> filecheck_result = RunFileCheck(module->ToString({}), R"(
555-
//CHECK: %lhs.2 = f32[] constant(42)
556-
//CHECK: %call1 = f32[] call(%lhs.2), to_apply=%sub, frontend_attributes={_xla_stream_annotation="1"}
557-
//CHECK: %rhs.2 = f32[] constant(2)
558-
//CHECK: ROOT %add.1 = f32[] add(%call1, %rhs.2)
559-
)");
560-
TF_ASSERT_OK(filecheck_result.status());
561-
EXPECT_TRUE(*filecheck_result);
562-
563-
ASSERT_TRUE(mutated);
564-
ASSERT_EQ(module->entry_computation()->instruction_count(), 4);
565-
auto inst = module->entry_computation()->instructions().begin();
566-
EXPECT_THAT(*inst, op::Constant());
567-
// Check that the annotated call isn't inlined
568-
++inst;
569-
EXPECT_THAT(*inst, op::Call());
570-
571-
// Check that the non-annotated call is still inlined
572-
++inst;
573-
EXPECT_THAT(*inst, op::Constant());
574-
++inst;
575-
EXPECT_THAT(*inst, op::Add());
576-
}
577-
578525
TEST_F(CallInlinerTest, ControlDepsPropagateToRootOfInlinedInstructions) {
579526
const char* hlo = R"(
580527
HloModule test

0 commit comments

Comments
 (0)