Skip to content

Commit 43eb396

Browse files
pschuh authored and Google-ML-Automation committed
Add StatusOr to transfer server BulkTransportInterface on the bond id to
forward errors from bond connection failures to the control plane connection. PiperOrigin-RevId: 820783819
1 parent 312c441 commit 43eb396

File tree

8 files changed

+51
-10
lines changed

8 files changed

+51
-10
lines changed

xla/python/transfer/socket-server.cc

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -299,10 +299,14 @@ class SocketServer::SocketNetworkState : public PollEventLoop::Handler {
299299
msg.data = const_cast<void*>(data);
300300
msg.size = size;
301301
msg.on_send = [val = tsl::FormRef(this), offset, req_id, is_largest](
302-
int bond_id, size_t size) {
302+
absl::StatusOr<int> bond_id, size_t size) {
303+
if (!bond_id.ok()) {
304+
val->SendError(req_id, offset, size, is_largest, bond_id.status());
305+
return;
306+
}
303307
SocketTransferRequest response;
304308
auto* packet = response.mutable_packet();
305-
packet->set_bulk_transport_id(bond_id);
309+
packet->set_bulk_transport_id(*bond_id);
306310
packet->set_offset(offset);
307311
packet->set_size(size);
308312
packet->set_req_id(req_id);

xla/python/transfer/socket_bulk_transport.cc

Lines changed: 21 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -180,6 +180,7 @@ class SendConnectionHandler : public PollEventLoop::Handler {
180180
artificial_send_limit_(artificial_send_limit) {}
181181

182182
~SendConnectionHandler() override {
183+
msg_queue_->Poison(absl::InternalError("A send connection has failed."));
183184
#ifdef MSG_ZEROCOPY
184185
table_.ClearAll();
185186
#endif
@@ -356,6 +357,19 @@ std::shared_ptr<SharedSendWorkQueue> SharedSendWorkQueue::Start() {
356357
return result;
357358
}
358359

360+
// Puts the queue into a failed state and fails everything still pending.
// Records `s` in poison_status_ and takes the queued work items while
// holding mu_; then, after mu_ is released, calls each item's on_send with
// `s` and the item's size, followed by its on_done.
void SharedSendMsgQueue::Poison(absl::Status s) {
361+
mu_.lock();
362+
poison_status_ = s;
363+
// Move the pending items into a local so they can be drained without mu_.
auto work_items = std::move(work_items_);
364+
mu_.unlock();
365+
while (!work_items.empty()) {
366+
auto work = std::move(work_items.front());
367+
work_items.pop_front();
368+
// on_send receives the poison status in place of a bond id.
std::move(work.on_send)(s, work.size);
369+
std::move(work.on_done)();
370+
}
371+
}
372+
359373
void SharedSendMsgQueue::ReportReadyToSend(SendConnectionHandler* handler) {
360374
mu_.lock();
361375
if (!work_items_.empty()) {
@@ -375,6 +389,13 @@ void SharedSendMsgQueue::ReportReadyToSend(SendConnectionHandler* handler) {
375389
void SharedSendMsgQueue::ScheduleSendWork(
376390
aux::BulkTransportInterface::SendMessage msg) {
377391
mu_.lock();
392+
if (!poison_status_.ok()) {
393+
auto s = poison_status_;
394+
mu_.unlock();
395+
std::move(msg.on_send)(std::move(s), msg.size);
396+
std::move(msg.on_done)();
397+
return;
398+
}
378399
DCHECK(!shutdown_);
379400
if (work_items_.empty() && !handlers_.empty()) {
380401
auto* handler = handlers_.front();

xla/python/transfer/socket_bulk_transport.h

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -129,13 +129,16 @@ class SharedSendMsgQueue {
129129
std::shared_ptr<SharedSendWorkQueue> work_queue,
130130
size_t artificial_send_limiti = std::numeric_limits<size_t>::max());
131131

132+
void Poison(absl::Status s);
133+
132134
private:
133135
friend class SendConnectionHandler;
134136

135137
void ReportReadyToSend(SendConnectionHandler* handler);
136138

137139
absl::Mutex mu_;
138140
bool shutdown_ = false;
141+
absl::Status poison_status_;
139142
std::deque<SendConnectionHandler*> handlers_;
140143
std::deque<aux::BulkTransportInterface::SendMessage> work_items_;
141144
};

xla/python/transfer/socket_bulk_transport_test.cc

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,7 @@ TEST(SendQueue, TestZeroCopyQueueCleanRemoteShutdown) {
8585
BulkTransportInterface::SendMessage msg;
8686
msg.data = txt_msg.data();
8787
msg.size = txt_msg.size();
88-
msg.on_send = [](int id, size_t size) {};
88+
msg.on_send = [](absl::StatusOr<int> id, size_t size) {};
8989
msg.on_done = [&notify]() { notify.Notify(); };
9090
msg_queue->ScheduleSendWork(std::move(msg));
9191
notify.WaitForNotification();
@@ -124,7 +124,7 @@ TEST(SendQueue, SendAndRecvQueuesArtificialLimit) {
124124
BulkTransportInterface::SendMessage msg;
125125
msg.data = txt_msg.data();
126126
msg.size = txt_msg.size();
127-
msg.on_send = [](int id, size_t size) {};
127+
msg.on_send = [](absl::StatusOr<int> id, size_t size) {};
128128
msg.on_done = [&mu, &send_count]() {
129129
absl::MutexLock l(mu);
130130
--send_count;
@@ -230,9 +230,9 @@ TEST(SocketBulkTransportFactoryTest, SendAndRecvWithFactory) {
230230
BulkTransportInterface::SendMessage msg;
231231
msg.data = txt_msgs[i].data();
232232
msg.size = txt_msgs[i].size();
233-
msg.on_send = [&, i](int id, size_t size) {
233+
msg.on_send = [&, i](absl::StatusOr<int> id, size_t size) {
234234
absl::MutexLock l(mu);
235-
send_queue.push_back({i, id});
235+
send_queue.push_back({i, id.value()});
236236
};
237237
msg.on_done = [&mu, &send_count]() {
238238
absl::MutexLock l(mu);

xla/python/transfer/streaming.cc

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -151,7 +151,10 @@ BulkTransportInterface::SendMessage BulkTransportInterface::MakeMessage(
151151
SendMessage result;
152152
result.data = tmp->data();
153153
result.size = tmp->size();
154-
result.on_send = std::move(on_send);
154+
result.on_send = [on_send = std::move(on_send)](absl::StatusOr<int> bond_id,
155+
size_t size) mutable {
156+
std::move(on_send)(bond_id.value(), size);
157+
};
155158
result.on_done = [tmp = std::move(tmp)]() {};
156159
return result;
157160
}

xla/python/transfer/streaming.h

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -107,7 +107,8 @@ class BulkTransportInterface {
107107
// There may be some delay between Send() and when the message
108108
// is actually sent. on_send gets called when the message actually
109109
// gets sent.
110-
absl::AnyInvocable<void(int bond_id, size_t size) &&> on_send;
110+
absl::AnyInvocable<void(absl::StatusOr<int> bond_id, size_t size) &&>
111+
on_send;
111112
};
112113

113114
// Schedules a send over a BulkTransportInterface connection.

xla/python/transfer/streaming_ifrt.cc

Lines changed: 10 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,9 @@ void PremappedCopierState::StartWorkUnlocked(const WorkList& work_list) {
180180
--num_parallel_copies_;
181181
work_item->is_ready = true;
182182
work_item->result_status = s;
183-
FlushReadyWorkItemsInOrder();
183+
if (!currently_flushing_) {
184+
FlushReadyWorkItemsInOrder();
185+
}
184186
work_list2 = FindWorkLocked();
185187
}
186188
StartWorkUnlocked(work_list2);
@@ -194,14 +196,20 @@ void PremappedCopierState::FlushReadyWorkItemsInOrder() {
194196
if (!work_item->is_ready) {
195197
return;
196198
}
199+
if (!work_item->result_status.ok()) {
200+
available_copy_offsets_.push_back(work_item->dest_buffer);
201+
}
202+
currently_flushing_ = true;
203+
mu_.unlock();
197204
if (work_item->result_status.ok()) {
198205
std::move(work_item->on_done)(this, work_item->dest_buffer,
199206
work_item->work);
200207
} else {
201208
std::move(work_item->on_done)(this, work_item->result_status,
202209
work_item->work);
203-
available_copy_offsets_.push_back(work_item->dest_buffer);
204210
}
211+
mu_.lock();
212+
currently_flushing_ = false;
205213
work_queue_.pop_front();
206214
++base_seq_id_;
207215
}

xla/python/transfer/streaming_ifrt.h

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -117,6 +117,7 @@ class PremappedCopierState
117117
size_t num_parallel_copies_ = 0;
118118
std::deque<WorkQueueItem> work_queue_ ABSL_GUARDED_BY(mu_);
119119
std::shared_ptr<absl::Span<uint8_t>> scratch_;
120+
bool currently_flushing_ ABSL_GUARDED_BY(mu_) = false;
120121
size_t max_num_parallel_copies_;
121122
size_t xfer_size_;
122123
size_t max_copies_;

0 commit comments

Comments
 (0)