Skip to content

Commit 7a28b74

Browse files
junwhanahn authored and Google-ML-Automation committed
Use an unbounded thread pool for StreamExecutor PjRt clients
If `tsl::thread::ThreadPool` runs out of threads, it runs the callback inline on the thread that called `Schedule()`. This pattern is known to be prone to deadlocks (e.g., if `Schedule` is called while holding a mutex that the callable needs to acquire, it will deadlock) and may harm the asynchrony in the tail case, so switching to an unbounded thread pool based on `tsl::UnboundedWorkQueue`. Since `tsl::thread::ThreadPool` is not extensible, this requires replacing the executor type from `tsl::thread::ThreadPool` to `AsyncWorkRunner`. The existing thread pool is still kept because `ExecutableBuildOptions::set_compile_thread_pool()` requires `tsl::thread::ThreadPool`. Reverts 485e912 PiperOrigin-RevId: 853019406
1 parent c8276b6 commit 7a28b74

22 files changed

+271
-124
lines changed

xla/backends/cpu/nanort/ifrt_client.cc

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ limitations under the License.
2424
#include <functional>
2525
#include <iterator>
2626
#include <memory>
27+
#include <new>
2728
#include <optional>
2829
#include <string>
2930
#include <utility>
@@ -48,6 +49,7 @@ limitations under the License.
4849
#include "xla/backends/cpu/alignment.h"
4950
#include "xla/backends/cpu/nanort/nanort_executable.h"
5051
#include "xla/hlo/builder/xla_computation.h"
52+
#include "xla/hlo/ir/hlo_input_output_alias_config.h"
5153
#include "xla/hlo/ir/hlo_module.h"
5254
#include "xla/hlo/ir/hlo_sharding.h"
5355
#include "xla/layout.h"
@@ -828,6 +830,8 @@ class NanoExecutable final
828830
TF_ASSIGN_OR_RETURN(auto nano_executable,
829831
client->nano_client()->Compile(computation));
830832

833+
TF_ASSIGN_OR_RETURN(auto donatable_input_indices,
834+
GetDonatableInputIndices(computation));
831835
TF_ASSIGN_OR_RETURN(auto program_shape, computation.GetProgramShape());
832836
TF_ASSIGN_OR_RETURN(auto proto_input_shardings,
833837
GetInputShardings(program_shape, computation));
@@ -840,8 +844,8 @@ class NanoExecutable final
840844

841845
return absl::WrapUnique(new NanoExecutable(
842846
client, std::move(computation), std::move(program_shape),
843-
std::move(nano_executable), std::move(input_shardings),
844-
std::move(output_shardings)));
847+
std::move(nano_executable), std::move(donatable_input_indices),
848+
std::move(input_shardings), std::move(output_shardings)));
845849
}
846850

847851
ifrt::Client* client() const override { return client_; }
@@ -850,8 +854,7 @@ class NanoExecutable final
850854

851855
absl::StatusOr<absl::Span<const int>> GetDonatableInputIndices()
852856
const override {
853-
return absl::UnimplementedError(
854-
"NanoExecutable::GetDonatableInputIndices is not implemented.");
857+
return donatable_input_indices_;
855858
}
856859

857860
absl::StatusOr<ExecuteResult> Execute(
@@ -1026,13 +1029,15 @@ class NanoExecutable final
10261029
NanoExecutable(NanoIfrtClient* client, XlaComputation program,
10271030
ProgramShape program_shape,
10281031
std::unique_ptr<NanoRtExecutable> executable,
1032+
std::vector<int> donatable_input_indices,
10291033
std::vector<ifrt::ShardingRef> input_shardings,
10301034
std::vector<ifrt::ShardingRef> output_shardings)
10311035
: client_(client),
10321036
devices_(ifrt::BasicDeviceList::Create(client->devices())),
10331037
program_(std::move(program)),
10341038
program_shape_(std::move(program_shape)),
10351039
executable_(std::move(executable)),
1040+
donatable_input_indices_(std::move(donatable_input_indices)),
10361041
input_shardings_(std::move(input_shardings)),
10371042
output_shardings_(std::move(output_shardings)),
10381043
user_context_(xla::ifrt::UserContextScope::current()) {}
@@ -1068,6 +1073,29 @@ class NanoExecutable final
10681073
return result;
10691074
}
10701075

1076+
// Returns a list of donatable input indices from the given HLO modules.
1077+
static absl::StatusOr<std::vector<int>> GetDonatableInputIndices(
1078+
const XlaComputation& xla_computation) {
1079+
const HloModuleProto& hlo_module_proto = xla_computation.proto();
1080+
std::vector<int> donatable_input_indices;
1081+
for (const auto& alias : hlo_module_proto.input_output_alias().entries()) {
1082+
if (alias.parameter_shape_index().empty()) {
1083+
donatable_input_indices.push_back(alias.parameter_number());
1084+
} else {
1085+
donatable_input_indices.push_back(alias.parameter_shape_index(0));
1086+
}
1087+
}
1088+
for (const auto& buffer_donor : hlo_module_proto.buffer_donor().entries()) {
1089+
if (buffer_donor.parameter_shape_index().empty()) {
1090+
donatable_input_indices.push_back(buffer_donor.parameter_number());
1091+
} else {
1092+
donatable_input_indices.push_back(
1093+
buffer_donor.parameter_shape_index(0));
1094+
}
1095+
}
1096+
return donatable_input_indices;
1097+
}
1098+
10711099
static absl::StatusOr<std::vector<OpSharding>> GetInputShardings(
10721100
const ProgramShape& program_shape, const XlaComputation& computation) {
10731101
std::vector<OpSharding> shardings(program_shape.parameters().size());
@@ -1176,6 +1204,7 @@ class NanoExecutable final
11761204
XlaComputation program_;
11771205
ProgramShape program_shape_;
11781206
std::unique_ptr<NanoRtExecutable> executable_;
1207+
std::vector<int> donatable_input_indices_;
11791208
std::vector<ifrt::ShardingRef> input_shardings_;
11801209
std::vector<ifrt::ShardingRef> output_shardings_;
11811210
const xla::ifrt::UserContextRef user_context_;

xla/pjrt/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ cc_library(
219219
"local_device_state.h",
220220
],
221221
deps = [
222+
":async_work_runner",
222223
":event_pool",
223224
":pjrt_common",
224225
":semaphore",
@@ -676,6 +677,7 @@ cc_library(
676677
visibility = internal_visibility(["//xla:friends"]),
677678
deps = [
678679
":abstract_tracked_device_buffer",
680+
":async_work_runner",
679681
":common_pjrt_client",
680682
":device_event",
681683
":event_pool",
@@ -1263,10 +1265,12 @@ cc_library(
12631265

12641266
cc_library(
12651267
name = "async_work_runner",
1268+
srcs = ["async_work_runner.cc"],
12661269
hdrs = ["async_work_runner.h"],
12671270
visibility = internal_visibility([":friends"]),
12681271
deps = [
12691272
"//xla/tsl/concurrency:async_value",
1273+
"//xla/tsl/concurrency:executor",
12701274
"//xla/tsl/concurrency:ref_count",
12711275
"@com_google_absl//absl/functional:any_invocable",
12721276
"@com_google_absl//absl/types:span",
@@ -1284,6 +1288,7 @@ cc_library(
12841288
"//xla/tsl/platform:env",
12851289
"@com_google_absl//absl/functional:any_invocable",
12861290
"@com_google_absl//absl/types:span",
1291+
"@tsl//tsl/platform:unbounded_work_queue",
12871292
],
12881293
)
12891294

xla/pjrt/async_work_runner.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/pjrt/async_work_runner.h"
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
#include "xla/tsl/concurrency/executor.h"
22+
23+
namespace xla {
24+
25+
namespace {
26+
27+
class AsyncWorkRunnerExecutor : public tsl::Executor {
28+
public:
29+
explicit AsyncWorkRunnerExecutor(AsyncWorkRunner* runner) : runner_(runner) {}
30+
31+
void Execute(Task task) override { runner_->Schedule(std::move(task)); }
32+
33+
private:
34+
AsyncWorkRunner* const runner_;
35+
};
36+
37+
} // namespace
38+
39+
AsyncWorkRunner::AsyncWorkRunner()
40+
: executor_(std::make_unique<AsyncWorkRunnerExecutor>(this)) {}
41+
42+
} // namespace xla

xla/pjrt/async_work_runner.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ limitations under the License.
1616
#ifndef XLA_PJRT_ASYNC_WORK_RUNNER_H_
1717
#define XLA_PJRT_ASYNC_WORK_RUNNER_H_
1818

19+
#include <memory>
20+
1921
#include "absl/functional/any_invocable.h"
2022
#include "absl/types/span.h"
2123
#include "xla/tsl/concurrency/async_value.h"
24+
#include "xla/tsl/concurrency/executor.h"
2225
#include "xla/tsl/concurrency/ref_count.h"
2326

2427
namespace xla {
@@ -27,13 +30,22 @@ namespace xla {
2730
// pool (or concurrent work queue).
2831
class AsyncWorkRunner {
2932
public:
33+
AsyncWorkRunner();
3034
virtual ~AsyncWorkRunner() = default;
3135

3236
// `work` enqueued by `Schedule` may run on the calling thread.
3337
virtual void Schedule(absl::AnyInvocable<void() &&> work) = 0;
3438
virtual void ScheduleWhenReady(
3539
absl::Span<const tsl::RCReference<tsl::AsyncValue>> values,
3640
absl::AnyInvocable<void() &&> work) = 0;
41+
42+
// Returns a tsl::Executor implementation that is backed by this async work
43+
// runner. The returned executor is owned by the async work runner and its
44+
// lifetime is bound to the lifetime of the async work runner itself.
45+
virtual tsl::Executor& AsExecutor() { return *executor_; }
46+
47+
private:
48+
std::unique_ptr<tsl::Executor> executor_;
3749
};
3850

3951
} // namespace xla

xla/pjrt/buffer_sequencing_event.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks(
133133
// Execute the `task` when definition event becomes available. If it's already
134134
// available, the task will be executed immediately.
135135
event_.AndThen([this, traced_task = std::move(traced_task)]() mutable {
136-
thread_pool_->Schedule(std::move(traced_task));
136+
async_work_runner_->Schedule(std::move(traced_task));
137137
});
138138
}
139139

xla/pjrt/buffer_sequencing_event.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ limitations under the License.
2626
#include "absl/log/log.h"
2727
#include "absl/status/status.h"
2828
#include "absl/synchronization/mutex.h"
29+
#include "xla/pjrt/async_work_runner.h"
2930
#include "xla/pjrt/event_pool.h"
3031
#include "xla/stream_executor/stream.h"
3132
#include "xla/tsl/concurrency/async_value.h"
3233
#include "xla/tsl/concurrency/async_value_ref.h"
33-
#include "xla/tsl/platform/threadpool.h"
3434

3535
namespace xla {
3636

@@ -71,18 +71,18 @@ class BufferSequencingEvent : tsl::AsyncPayload::KeepOnError {
7171
se::Stream* definition_stream;
7272
};
7373

74-
explicit BufferSequencingEvent(tsl::thread::ThreadPool* thread_pool)
75-
: thread_pool_(thread_pool),
74+
explicit BufferSequencingEvent(AsyncWorkRunner* async_work_runner)
75+
: async_work_runner_(async_work_runner),
7676
event_(tsl::MakeUnconstructedAsyncValueRef<EventState>()) {}
7777

78-
explicit BufferSequencingEvent(tsl::thread::ThreadPool* thread_pool,
78+
explicit BufferSequencingEvent(AsyncWorkRunner* async_work_runner,
7979
tsl::AsyncValueRef<EventState> event)
80-
: thread_pool_(thread_pool), event_(event) {}
80+
: async_work_runner_(async_work_runner), event_(event) {}
8181

8282
static tsl::AsyncValueRef<BufferSequencingEvent> Create(
83-
tsl::thread::ThreadPool* thread_pool) {
83+
AsyncWorkRunner* async_work_runner) {
8484
return tsl::MakeConstructedAsyncValueRef<BufferSequencingEvent>(
85-
thread_pool);
85+
async_work_runner);
8686
}
8787

8888
// Sets the sequencing event to 'event', which is recorded on 'stream'. Must
@@ -164,7 +164,7 @@ class BufferSequencingEvent : tsl::AsyncPayload::KeepOnError {
164164
// at the tail of the queue, i.e., for any newly enqueued command.
165165
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ ABSL_GUARDED_BY(mu_);
166166

167-
tsl::thread::ThreadPool* thread_pool_;
167+
AsyncWorkRunner* async_work_runner_;
168168

169169
// Indicates if the buffer is in an error status. And error status is used to
170170
// propagate the error to the buffer consumers.

xla/pjrt/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cc_library(
7979
"//xla/core/collectives:rank_id",
8080
"//xla/hlo/builder:xla_computation",
8181
"//xla/pjrt:abstract_tracked_device_buffer",
82+
"//xla/pjrt:async_work_runner",
8283
"//xla/pjrt:common_pjrt_client",
8384
"//xla/pjrt:device_event",
8485
"//xla/pjrt:event_pool",

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ limitations under the License.
6464
#include "xla/layout.h"
6565
#include "xla/literal.h"
6666
#include "xla/pjrt/abstract_tracked_device_buffer.h"
67+
#include "xla/pjrt/async_work_runner.h"
6768
#include "xla/pjrt/buffer_sequencing_event.h"
6869
#include "xla/pjrt/common_pjrt_client.h"
6970
#include "xla/pjrt/device_event.h"
@@ -152,11 +153,11 @@ limitations under the License.
152153
namespace xla {
153154

154155
absl::Status RunCallbackOnStream(se::Stream* stream,
155-
tsl::thread::ThreadPool* thread_pool,
156+
AsyncWorkRunner* async_work_runner,
156157
absl::AnyInvocable<void() &&> callback) {
157158
return stream->DoHostCallbackWithStatus(
158-
[cb = std::move(callback), thread_pool]() mutable {
159-
thread_pool->Schedule(
159+
[cb = std::move(callback), async_work_runner]() mutable {
160+
async_work_runner->Schedule(
160161
[cb_ptr = new absl::AnyInvocable<void() &&>(std::move(cb))]() {
161162
std::move (*cb_ptr)();
162163
delete cb_ptr;
@@ -761,7 +762,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
761762
gpu::GpuCollectives* gpu_collectives =
762763
gpu::GpuCollectives::Default(stream->parent()->GetPlatform()->Name());
763764
usage_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
764-
BufferSequencingEvent::Create(this->thread_pool()));
765+
BufferSequencingEvent::Create(this->async_work_runner()));
765766

766767
gpu::AcquiredCliquesMap acquired_cliques_map;
767768
for (int i = 0; i < buffers.size(); ++i) {
@@ -853,7 +854,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
853854
Future<> all_sends_future = JoinFutures(group_futures);
854855

855856
all_sends_future.OnReady(
856-
*this->thread_pool()->AsExecutor(),
857+
this->async_work_runner()->AsExecutor(),
857858
[this, local_device_state, stream, promises = std::move(promises),
858859
usage_event, grouped_sends = std::move(grouped_sends)](
859860
const absl::Status& status) mutable {
@@ -870,7 +871,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
870871
// Asynchronously fulfill promises via a host callback, failing them
871872
// early if there is an issue registering the callback.
872873
absl::Status callback_status = RunCallbackOnStream(
873-
stream, this->thread_pool(), [promises]() mutable {
874+
stream, this->async_work_runner(), [promises]() mutable {
874875
FulfillPromises(promises, absl::OkStatus());
875876
});
876877

@@ -911,7 +912,7 @@ StreamExecutorGpuClient::PrepareReceiveBuffer(PjRtDevice* device, Shape shape) {
911912
se::Stream* stream = local_device->GetDeviceToDeviceStream();
912913

913914
BufferSequencingEventRef definition_event =
914-
BufferSequencingEvent::Create(this->thread_pool());
915+
BufferSequencingEvent::Create(this->async_work_runner());
915916
TF_ASSIGN_OR_RETURN(
916917
auto buffer,
917918
DefineBuffer(
@@ -981,7 +982,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
981982
gpu::GpuCollectives* gpu_collectives =
982983
gpu::GpuCollectives::Default(stream->parent()->GetPlatform()->Name());
983984
definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
984-
BufferSequencingEvent::Create(this->thread_pool()));
985+
BufferSequencingEvent::Create(this->async_work_runner()));
985986

986987
gpu::AcquiredCliquesMap acquired_cliques_map;
987988
for (int i = 0; i < shapes.size(); ++i) {
@@ -1064,7 +1065,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
10641065
Future<> all_receives_future = JoinFutures(group_futures);
10651066

10661067
all_receives_future.OnReady(
1067-
*this->thread_pool()->AsExecutor(),
1068+
this->async_work_runner()->AsExecutor(),
10681069
[this, local_device_state, stream,
10691070
grouped_receives = std::move(grouped_receives),
10701071
definition_event = std::move(definition_event)](
@@ -1105,7 +1106,7 @@ void StreamExecutorGpuClient::ScheduleRemoteSend(
11051106
}
11061107

11071108
BufferSequencingEventRef usage_event =
1108-
BufferSequencingEvent::Create(this->thread_pool());
1109+
BufferSequencingEvent::Create(this->async_work_runner());
11091110

11101111
// Keep memory alive until the event is done.
11111112
usage_event.AndThen([raw_buffer]() {});
@@ -1259,7 +1260,7 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
12591260
SetEventAsError(definition_event, s);
12601261
}
12611262
};
1262-
thread_pool()->Schedule(recv);
1263+
async_work_runner()->Schedule(recv);
12631264

12641265
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
12651266
buffers.push_back(std::move(receive_prep_result.buffer));

0 commit comments

Comments
 (0)