Skip to content

Commit 869808b

Browse files
Ensure only AsyncStart defines the output.
PiperOrigin-RevId: 681066552
1 parent 43a6d82 commit 869808b

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

xla/service/latency_hiding_scheduler.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,10 @@ bool InstructionDefinesValue(const HloInstruction* instruction,
8585
}
8686
// Also check if the instruction is a call to a computation that defines the
8787
// value. This is needed in cases, e.g., where we wrap a value-defining
88-
// instruction in a async call for offloading, and the async call itself will
88+
// instruction in a async call for offloading, and the async start itself will
8989
// effectively define the value in the current scope that the scheduler is
9090
// running in.
91-
if (instruction->opcode() == HloOpcode::kAsyncStart ||
92-
instruction->opcode() == HloOpcode::kAsyncDone) {
91+
if (instruction->opcode() == HloOpcode::kAsyncStart) {
9392
if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
9493
return instruction->async_wrapped_instruction()
9594
->called_computations()[0]
@@ -114,8 +113,7 @@ bool InstructionFirstDefinesBuffer(
114113
}
115114
// Similar to logic above, also check if the instruction is a call to a
116115
// computation that defines the value.
117-
if (instruction->opcode() == HloOpcode::kAsyncStart ||
118-
instruction->opcode() == HloOpcode::kAsyncDone) {
116+
if (instruction->opcode() == HloOpcode::kAsyncStart) {
119117
if (instruction->async_wrapped_opcode() == HloOpcode::kCall) {
120118
return instruction->async_wrapped_instruction()
121119
->called_computations()[0]

0 commit comments

Comments
 (0)