Skip to content

Commit f4bab1d

Browse files
junwhanahnGoogle-ML-Automation
authored andcommitted
Use an unbounded thread pool for StreamExecutor PjRt clients
If `tsl::thread::ThreadPool` runs out of threads, it runs the callback inline on the thread that called `Schedule()`. This pattern is known to be prone to deadlocks (e.g., if `Schedule` is called while holding a mutex that the callable needs to acquire, it will deadlock) and may harm the asynchrony in the tail case, so switching to an unbounded thread pool based on `tsl::UnboundedWorkQueue`. Since `tsl::thread::ThreadPool` is not extensible, this requires replacing the executor type from `tsl::thread::ThreadPool` to `AsyncWorkRunner`. The existing thread pool is still kept because `ExecutableBuildOptions::set_compile_thread_pool()` requires `tsl::thread::ThreadPool`. PiperOrigin-RevId: 853352782
1 parent 4c76b1a commit f4bab1d

18 files changed

+232
-115
lines changed

xla/pjrt/BUILD

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ cc_library(
219219
"local_device_state.h",
220220
],
221221
deps = [
222+
":async_work_runner",
222223
":event_pool",
223224
":pjrt_common",
224225
":semaphore",
@@ -676,6 +677,7 @@ cc_library(
676677
visibility = internal_visibility(["//xla:friends"]),
677678
deps = [
678679
":abstract_tracked_device_buffer",
680+
":async_work_runner",
679681
":common_pjrt_client",
680682
":device_event",
681683
":event_pool",
@@ -1263,10 +1265,12 @@ cc_library(
12631265

12641266
cc_library(
12651267
name = "async_work_runner",
1268+
srcs = ["async_work_runner.cc"],
12661269
hdrs = ["async_work_runner.h"],
12671270
visibility = internal_visibility([":friends"]),
12681271
deps = [
12691272
"//xla/tsl/concurrency:async_value",
1273+
"//xla/tsl/concurrency:executor",
12701274
"//xla/tsl/concurrency:ref_count",
12711275
"@com_google_absl//absl/functional:any_invocable",
12721276
"@com_google_absl//absl/types:span",
@@ -1284,6 +1288,7 @@ cc_library(
12841288
"//xla/tsl/platform:env",
12851289
"@com_google_absl//absl/functional:any_invocable",
12861290
"@com_google_absl//absl/types:span",
1291+
"@tsl//tsl/platform:unbounded_work_queue",
12871292
],
12881293
)
12891294

xla/pjrt/async_work_runner.cc

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/* Copyright 2026 The OpenXLA Authors.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
#include "xla/pjrt/async_work_runner.h"
17+
18+
#include <memory>
19+
#include <utility>
20+
21+
#include "xla/tsl/concurrency/executor.h"
22+
23+
namespace xla {
24+
25+
namespace {
26+
27+
class AsyncWorkRunnerExecutor : public tsl::Executor {
28+
public:
29+
explicit AsyncWorkRunnerExecutor(AsyncWorkRunner* runner) : runner_(runner) {}
30+
31+
void Execute(Task task) override { runner_->Schedule(std::move(task)); }
32+
33+
private:
34+
AsyncWorkRunner* const runner_;
35+
};
36+
37+
} // namespace
38+
39+
AsyncWorkRunner::AsyncWorkRunner()
40+
: executor_(std::make_unique<AsyncWorkRunnerExecutor>(this)) {}
41+
42+
} // namespace xla

xla/pjrt/async_work_runner.h

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,12 @@ limitations under the License.
1616
#ifndef XLA_PJRT_ASYNC_WORK_RUNNER_H_
1717
#define XLA_PJRT_ASYNC_WORK_RUNNER_H_
1818

19+
#include <memory>
20+
1921
#include "absl/functional/any_invocable.h"
2022
#include "absl/types/span.h"
2123
#include "xla/tsl/concurrency/async_value.h"
24+
#include "xla/tsl/concurrency/executor.h"
2225
#include "xla/tsl/concurrency/ref_count.h"
2326

2427
namespace xla {
@@ -27,13 +30,22 @@ namespace xla {
2730
// pool (or concurrent work queue).
2831
class AsyncWorkRunner {
2932
public:
33+
AsyncWorkRunner();
3034
virtual ~AsyncWorkRunner() = default;
3135

3236
// `work` euqueued by `Schedule` may run on the calling thread.
3337
virtual void Schedule(absl::AnyInvocable<void() &&> work) = 0;
3438
virtual void ScheduleWhenReady(
3539
absl::Span<const tsl::RCReference<tsl::AsyncValue>> values,
3640
absl::AnyInvocable<void() &&> work) = 0;
41+
42+
// Returns an tsl::Executor implementation that is backed by this async work
43+
// runner. The returned executor is owned by the async work runner and its
44+
// lifetime is bound to the lifetime of the thread pool itself.
45+
virtual tsl::Executor& AsExecutor() { return *executor_; }
46+
47+
private:
48+
std::unique_ptr<tsl::Executor> executor_;
3749
};
3850

3951
} // namespace xla

xla/pjrt/buffer_sequencing_event.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ void BufferSequencingEvent::ExecuteOrAddToFutureTasks(
133133
// Execute the `task` when definition event becomes available. If it's already
134134
// available, the task will be executed immediately.
135135
event_.AndThen([this, traced_task = std::move(traced_task)]() mutable {
136-
thread_pool_->Schedule(std::move(traced_task));
136+
async_work_runner_->Schedule(std::move(traced_task));
137137
});
138138
}
139139

xla/pjrt/buffer_sequencing_event.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ limitations under the License.
2626
#include "absl/log/log.h"
2727
#include "absl/status/status.h"
2828
#include "absl/synchronization/mutex.h"
29+
#include "xla/pjrt/async_work_runner.h"
2930
#include "xla/pjrt/event_pool.h"
3031
#include "xla/stream_executor/stream.h"
3132
#include "xla/tsl/concurrency/async_value.h"
3233
#include "xla/tsl/concurrency/async_value_ref.h"
33-
#include "xla/tsl/platform/threadpool.h"
3434

3535
namespace xla {
3636

@@ -71,18 +71,18 @@ class BufferSequencingEvent : tsl::AsyncPayload::KeepOnError {
7171
se::Stream* definition_stream;
7272
};
7373

74-
explicit BufferSequencingEvent(tsl::thread::ThreadPool* thread_pool)
75-
: thread_pool_(thread_pool),
74+
explicit BufferSequencingEvent(AsyncWorkRunner* async_work_runner)
75+
: async_work_runner_(async_work_runner),
7676
event_(tsl::MakeUnconstructedAsyncValueRef<EventState>()) {}
7777

78-
explicit BufferSequencingEvent(tsl::thread::ThreadPool* thread_pool,
78+
explicit BufferSequencingEvent(AsyncWorkRunner* async_work_runner,
7979
tsl::AsyncValueRef<EventState> event)
80-
: thread_pool_(thread_pool), event_(event) {}
80+
: async_work_runner_(async_work_runner), event_(event) {}
8181

8282
static tsl::AsyncValueRef<BufferSequencingEvent> Create(
83-
tsl::thread::ThreadPool* thread_pool) {
83+
AsyncWorkRunner* async_work_runner) {
8484
return tsl::MakeConstructedAsyncValueRef<BufferSequencingEvent>(
85-
thread_pool);
85+
async_work_runner);
8686
}
8787

8888
// Sets the sequencing event to 'event', which is recorded on 'stream'. Must
@@ -164,7 +164,7 @@ class BufferSequencingEvent : tsl::AsyncPayload::KeepOnError {
164164
// at the tail of the queue, i.e., for any newly enqueued command.
165165
absl::InlinedVector<se::Stream*, 2> streams_defined_on_ ABSL_GUARDED_BY(mu_);
166166

167-
tsl::thread::ThreadPool* thread_pool_;
167+
AsyncWorkRunner* async_work_runner_;
168168

169169
// Indicates if the buffer is in an error status. And error status is used to
170170
// propagate the error to the buffer consumers.

xla/pjrt/gpu/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ cc_library(
7979
"//xla/core/collectives:rank_id",
8080
"//xla/hlo/builder:xla_computation",
8181
"//xla/pjrt:abstract_tracked_device_buffer",
82+
"//xla/pjrt:async_work_runner",
8283
"//xla/pjrt:common_pjrt_client",
8384
"//xla/pjrt:device_event",
8485
"//xla/pjrt:event_pool",

xla/pjrt/gpu/se_gpu_pjrt_client.cc

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,7 @@ limitations under the License.
6464
#include "xla/layout.h"
6565
#include "xla/literal.h"
6666
#include "xla/pjrt/abstract_tracked_device_buffer.h"
67+
#include "xla/pjrt/async_work_runner.h"
6768
#include "xla/pjrt/buffer_sequencing_event.h"
6869
#include "xla/pjrt/common_pjrt_client.h"
6970
#include "xla/pjrt/device_event.h"
@@ -152,11 +153,11 @@ limitations under the License.
152153
namespace xla {
153154

154155
absl::Status RunCallbackOnStream(se::Stream* stream,
155-
tsl::thread::ThreadPool* thread_pool,
156+
AsyncWorkRunner* async_work_runner,
156157
absl::AnyInvocable<void() &&> callback) {
157158
return stream->DoHostCallbackWithStatus(
158-
[cb = std::move(callback), thread_pool]() mutable {
159-
thread_pool->Schedule(
159+
[cb = std::move(callback), async_work_runner]() mutable {
160+
async_work_runner->Schedule(
160161
[cb_ptr = new absl::AnyInvocable<void() &&>(std::move(cb))]() {
161162
std::move (*cb_ptr)();
162163
delete cb_ptr;
@@ -761,7 +762,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
761762
gpu::GpuCollectives* gpu_collectives =
762763
gpu::GpuCollectives::Default(stream->parent()->GetPlatform()->Name());
763764
usage_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
764-
BufferSequencingEvent::Create(this->thread_pool()));
765+
BufferSequencingEvent::Create(this->async_work_runner()));
765766

766767
gpu::AcquiredCliquesMap acquired_cliques_map;
767768
for (int i = 0; i < buffers.size(); ++i) {
@@ -853,7 +854,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
853854
Future<> all_sends_future = JoinFutures(group_futures);
854855

855856
all_sends_future.OnReady(
856-
*this->thread_pool()->AsExecutor(),
857+
this->async_work_runner()->AsExecutor(),
857858
[this, local_device_state, stream, promises = std::move(promises),
858859
usage_event, grouped_sends = std::move(grouped_sends)](
859860
const absl::Status& status) mutable {
@@ -870,7 +871,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
870871
// Asynchronously fulfill promises via a host callback, failing them
871872
// early if there is an issue registering the callback.
872873
absl::Status callback_status = RunCallbackOnStream(
873-
stream, this->thread_pool(), [promises]() mutable {
874+
stream, this->async_work_runner(), [promises]() mutable {
874875
FulfillPromises(promises, absl::OkStatus());
875876
});
876877

@@ -911,7 +912,7 @@ StreamExecutorGpuClient::PrepareReceiveBuffer(PjRtDevice* device, Shape shape) {
911912
se::Stream* stream = local_device->GetDeviceToDeviceStream();
912913

913914
BufferSequencingEventRef definition_event =
914-
BufferSequencingEvent::Create(this->thread_pool());
915+
BufferSequencingEvent::Create(this->async_work_runner());
915916
TF_ASSIGN_OR_RETURN(
916917
auto buffer,
917918
DefineBuffer(
@@ -981,7 +982,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
981982
gpu::GpuCollectives* gpu_collectives =
982983
gpu::GpuCollectives::Default(stream->parent()->GetPlatform()->Name());
983984
definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
984-
BufferSequencingEvent::Create(this->thread_pool()));
985+
BufferSequencingEvent::Create(this->async_work_runner()));
985986

986987
gpu::AcquiredCliquesMap acquired_cliques_map;
987988
for (int i = 0; i < shapes.size(); ++i) {
@@ -1064,7 +1065,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
10641065
Future<> all_receives_future = JoinFutures(group_futures);
10651066

10661067
all_receives_future.OnReady(
1067-
*this->thread_pool()->AsExecutor(),
1068+
this->async_work_runner()->AsExecutor(),
10681069
[this, local_device_state, stream,
10691070
grouped_receives = std::move(grouped_receives),
10701071
definition_event = std::move(definition_event)](
@@ -1105,7 +1106,7 @@ void StreamExecutorGpuClient::ScheduleRemoteSend(
11051106
}
11061107

11071108
BufferSequencingEventRef usage_event =
1108-
BufferSequencingEvent::Create(this->thread_pool());
1109+
BufferSequencingEvent::Create(this->async_work_runner());
11091110

11101111
// Keep memory alive until the event is done.
11111112
usage_event.AndThen([raw_buffer]() {});
@@ -1259,7 +1260,7 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
12591260
SetEventAsError(definition_event, s);
12601261
}
12611262
};
1262-
thread_pool()->Schedule(recv);
1263+
async_work_runner()->Schedule(recv);
12631264

12641265
std::vector<std::unique_ptr<PjRtBuffer>> buffers;
12651266
buffers.push_back(std::move(receive_prep_result.buffer));

xla/pjrt/gpu/se_gpu_pjrt_client_test.cc

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2837,9 +2837,9 @@ ENTRY main.5 {
28372837
TEST(StreamExecutorGpuClientTest, EventCaching) {
28382838
TF_ASSERT_OK_AND_ASSIGN(auto client,
28392839
GetStreamExecutorGpuClient(DefaultOptions()));
2840-
auto* thread_pool =
2840+
auto* async_work_runner =
28412841
tensorflow::down_cast<PjRtStreamExecutorClient*>(client.get())
2842-
->thread_pool();
2842+
->async_work_runner();
28432843
const auto& device = client->addressable_devices()[0];
28442844
LocalDeviceState* local_device_state =
28452845
tensorflow::down_cast<const PjRtStreamExecutorDevice*>(device)
@@ -2848,14 +2848,14 @@ TEST(StreamExecutorGpuClientTest, EventCaching) {
28482848
size_t sync_point0 = local_device_state->GetNextComputeStreamSyncPoint();
28492849
TF_ASSERT_OK_AND_ASSIGN(auto event0,
28502850
local_device_state->GetEventForComputeStreamSyncPoint(
2851-
sync_point0, thread_pool));
2851+
sync_point0, async_work_runner));
28522852
TF_ASSERT_OK_AND_ASSIGN(auto event1,
28532853
local_device_state->GetEventForComputeStreamSyncPoint(
2854-
sync_point0, thread_pool));
2854+
sync_point0, async_work_runner));
28552855
size_t sync_point1 = local_device_state->GetNextComputeStreamSyncPoint();
28562856
TF_ASSERT_OK_AND_ASSIGN(auto event2,
28572857
local_device_state->GetEventForComputeStreamSyncPoint(
2858-
sync_point1, thread_pool));
2858+
sync_point1, async_work_runner));
28592859
// Events are getting cached.
28602860
EXPECT_EQ(&*event0, &*event1);
28612861
// New events are getting assigned.
@@ -2864,7 +2864,7 @@ TEST(StreamExecutorGpuClientTest, EventCaching) {
28642864
// sync_point1 is ready, so it is the most recent event.
28652865
TF_ASSERT_OK_AND_ASSIGN(auto event3,
28662866
local_device_state->GetEventForComputeStreamSyncPoint(
2867-
sync_point0, thread_pool));
2867+
sync_point0, async_work_runner));
28682868
EXPECT_EQ(&*event3, &*event2);
28692869
}
28702870

xla/pjrt/local_device_state.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "absl/strings/str_format.h"
3333
#include "absl/synchronization/mutex.h"
3434
#include "xla/client/local_client.h"
35+
#include "xla/pjrt/async_work_runner.h"
3536
#include "xla/pjrt/buffer_sequencing_event.h"
3637
#include "xla/pjrt/worker_thread.h"
3738
#include "xla/stream_executor/device_address.h"
@@ -325,7 +326,7 @@ absl::Status LocalDeviceState::AllocateAndRecordEvent(
325326

326327
absl::StatusOr<BufferSequencingEventRef>
327328
LocalDeviceState::GetEventForComputeStreamSyncPoint(
328-
size_t sync_point, tsl::thread::ThreadPool* thread_pool,
329+
size_t sync_point, AsyncWorkRunner* async_work_runner,
329330
bool nullptr_if_past) {
330331
mu_.lock();
331332
size_t cur_sync_point = next_compute_stream_sync_point_.load();
@@ -343,7 +344,7 @@ LocalDeviceState::GetEventForComputeStreamSyncPoint(
343344
return event;
344345
}
345346
next_compute_stream_sync_point_.store(cur_sync_point + 1);
346-
auto event = BufferSequencingEvent::Create(thread_pool);
347+
auto event = BufferSequencingEvent::Create(async_work_runner);
347348
auto status = AllocateAndRecordEvent(event, compute_stream());
348349
if (!status.ok()) {
349350
mu_.unlock();

xla/pjrt/local_device_state.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ limitations under the License.
3131
#include "absl/status/statusor.h"
3232
#include "absl/synchronization/mutex.h"
3333
#include "xla/client/local_client.h"
34+
#include "xla/pjrt/async_work_runner.h"
3435
#include "xla/pjrt/buffer_sequencing_event.h"
3536
#include "xla/pjrt/event_pool.h"
3637
#include "xla/pjrt/pjrt_common.h"
@@ -224,7 +225,7 @@ class LocalDeviceState {
224225
// which only incur the expense of constructing a cuda event if they're really
225226
// needed. This allows constructing a definition event per buffer.
226227
absl::StatusOr<BufferSequencingEventRef> GetEventForComputeStreamSyncPoint(
227-
size_t sync_point, tsl::thread::ThreadPool* thread_pool,
228+
size_t sync_point, AsyncWorkRunner* async_work_runner,
228229
bool nullptr_if_past = false);
229230

230231
private:

0 commit comments

Comments
 (0)