@@ -64,6 +64,7 @@ limitations under the License.
6464#include " xla/layout.h"
6565#include " xla/literal.h"
6666#include " xla/pjrt/abstract_tracked_device_buffer.h"
67+ #include " xla/pjrt/async_work_runner.h"
6768#include " xla/pjrt/buffer_sequencing_event.h"
6869#include " xla/pjrt/common_pjrt_client.h"
6970#include " xla/pjrt/device_event.h"
@@ -152,11 +153,11 @@ limitations under the License.
152153namespace xla {
153154
154155absl::Status RunCallbackOnStream (se::Stream* stream,
155- tsl::thread::ThreadPool* thread_pool ,
156+ AsyncWorkRunner* async_work_runner ,
156157 absl::AnyInvocable<void () &&> callback) {
157158 return stream->DoHostCallbackWithStatus (
158- [cb = std::move (callback), thread_pool ]() mutable {
159- thread_pool ->Schedule (
159+ [cb = std::move (callback), async_work_runner ]() mutable {
160+ async_work_runner ->Schedule (
160161 [cb_ptr = new absl::AnyInvocable<void () &&>(std::move (cb))]() {
161162 std::move (*cb_ptr)();
162163 delete cb_ptr;
@@ -761,7 +762,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
761762 gpu::GpuCollectives* gpu_collectives =
762763 gpu::GpuCollectives::Default (stream->parent ()->GetPlatform ()->Name ());
763764 usage_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
764- BufferSequencingEvent::Create (this ->thread_pool ()));
765+ BufferSequencingEvent::Create (this ->async_work_runner ()));
765766
766767 gpu::AcquiredCliquesMap acquired_cliques_map;
767768 for (int i = 0 ; i < buffers.size (); ++i) {
@@ -853,7 +854,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
853854 Future<> all_sends_future = JoinFutures (group_futures);
854855
855856 all_sends_future.OnReady (
856- * this ->thread_pool ()->AsExecutor (),
857+ this ->async_work_runner ()->AsExecutor (),
857858 [this , local_device_state, stream, promises = std::move (promises),
858859 usage_event, grouped_sends = std::move (grouped_sends)](
859860 const absl::Status& status) mutable {
@@ -870,7 +871,7 @@ void StreamExecutorGpuClient::ScheduleSendsOnLocalDevice(
870871 // Asynchronously fulfill promises via a host callback, failing them
871872 // early if there is an issue registering the callback.
872873 absl::Status callback_status = RunCallbackOnStream (
873- stream, this ->thread_pool (), [promises]() mutable {
874+ stream, this ->async_work_runner (), [promises]() mutable {
874875 FulfillPromises (promises, absl::OkStatus ());
875876 });
876877
@@ -911,7 +912,7 @@ StreamExecutorGpuClient::PrepareReceiveBuffer(PjRtDevice* device, Shape shape) {
911912 se::Stream* stream = local_device->GetDeviceToDeviceStream ();
912913
913914 BufferSequencingEventRef definition_event =
914- BufferSequencingEvent::Create (this ->thread_pool ());
915+ BufferSequencingEvent::Create (this ->async_work_runner ());
915916 TF_ASSIGN_OR_RETURN (
916917 auto buffer,
917918 DefineBuffer (
@@ -981,7 +982,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
981982 gpu::GpuCollectives* gpu_collectives =
982983 gpu::GpuCollectives::Default (stream->parent ()->GetPlatform ()->Name ());
983984 definition_event = tsl::MakeRef<PjRtStreamExecutorDeviceEvent>(
984- BufferSequencingEvent::Create (this ->thread_pool ()));
985+ BufferSequencingEvent::Create (this ->async_work_runner ()));
985986
986987 gpu::AcquiredCliquesMap acquired_cliques_map;
987988 for (int i = 0 ; i < shapes.size (); ++i) {
@@ -1064,7 +1065,7 @@ StreamExecutorGpuClient::CrossHostReceiveBuffers(
10641065 Future<> all_receives_future = JoinFutures (group_futures);
10651066
10661067 all_receives_future.OnReady (
1067- * this ->thread_pool ()->AsExecutor (),
1068+ this ->async_work_runner ()->AsExecutor (),
10681069 [this , local_device_state, stream,
10691070 grouped_receives = std::move (grouped_receives),
10701071 definition_event = std::move (definition_event)](
@@ -1105,7 +1106,7 @@ void StreamExecutorGpuClient::ScheduleRemoteSend(
11051106 }
11061107
11071108 BufferSequencingEventRef usage_event =
1108- BufferSequencingEvent::Create (this ->thread_pool ());
1109+ BufferSequencingEvent::Create (this ->async_work_runner ());
11091110
11101111 // Keep memory alive until the event is done.
11111112 usage_event.AndThen ([raw_buffer]() {});
@@ -1259,7 +1260,7 @@ StreamExecutorGpuClient::MakeCrossHostReceiveBuffers(
12591260 SetEventAsError (definition_event, s);
12601261 }
12611262 };
1262- thread_pool ()->Schedule (recv);
1263+ async_work_runner ()->Schedule (recv);
12631264
12641265 std::vector<std::unique_ptr<PjRtBuffer>> buffers;
12651266 buffers.push_back (std::move (receive_prep_result.buffer ));
0 commit comments