Skip to content

Commit fadb0e8

Browse files
pschuhGoogle-ML-Automation
authored andcommitted
Update PremappedCopierState::FlushReadyWorkItemsInOrder to ensure on_done gests
destroyed as part of invocation. PiperOrigin-RevId: 827739079
1 parent 0506f18 commit fadb0e8

File tree

2 files changed

+30
-24
lines changed

2 files changed

+30
-24
lines changed

xla/python/transfer/streaming_ifrt.cc

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,13 @@ void PremappedCopierState::FlushReadyWorkItemsInOrder() {
201201
}
202202
currently_flushing_ = true;
203203
mu_.unlock();
204-
if (work_item->result_status.ok()) {
205-
std::move(work_item->on_done)(this, work_item->dest_buffer,
206-
work_item->work);
207-
} else {
208-
std::move(work_item->on_done)(this, work_item->result_status,
209-
work_item->work);
204+
{
205+
auto on_done_fn = std::move(work_item->on_done);
206+
if (work_item->result_status.ok()) {
207+
std::move(on_done_fn)(this, work_item->dest_buffer, work_item->work);
208+
} else {
209+
std::move(on_done_fn)(this, work_item->result_status, work_item->work);
210+
}
210211
}
211212
mu_.lock();
212213
currently_flushing_ = false;

xla/python/transfer/streaming_ifrt_test.cc

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -165,20 +165,23 @@ TEST(PremappedCopierState, FreeCycle) {
165165
TF_ASSERT_OK_AND_ASSIGN(
166166
auto scratch, AllocateAndMapPjrtMemory(pjrt_client, 1024 * 1024 * 16));
167167
auto cstate = std::make_shared<PremappedCopierState>(scratch, 4, 4096);
168-
void* buffer_to_return = nullptr;
169-
cstate->ScheduleCopy({/*copy_fn=*/[](void* dst, int64_t offset,
170-
int64_t transfer_size) -> xla::Future<> {
171-
return xla::Future<>(absl::OkStatus());
172-
},
173-
/*buffer_id=*/0,
174-
/*offset=*/100,
175-
/*size=*/100},
176-
[&buffer_to_return](PremappedCopierState* state,
177-
absl::StatusOr<void*> buf,
178-
const DmaCopyChunk& chunk) {
179-
TF_CHECK_OK(buf.status());
180-
buffer_to_return = buf.value();
181-
});
168+
std::vector<void*> buffers_to_return;
169+
for (size_t i = 0; i < 2; ++i) {
170+
cstate->ScheduleCopy(
171+
{/*copy_fn=*/[](void* dst, int64_t offset,
172+
int64_t transfer_size) -> xla::Future<> {
173+
return xla::Future<>(absl::OkStatus());
174+
},
175+
/*buffer_id=*/0,
176+
/*offset=*/100,
177+
/*size=*/100},
178+
[&buffers_to_return](PremappedCopierState* state,
179+
absl::StatusOr<void*> buf,
180+
const DmaCopyChunk& chunk) {
181+
TF_CHECK_OK(buf.status());
182+
buffers_to_return.push_back(buf.value());
183+
});
184+
}
182185
class BufferReturner {
183186
public:
184187
explicit BufferReturner(absl::AnyInvocable<void() &&> on_done)
@@ -190,17 +193,19 @@ TEST(PremappedCopierState, FreeCycle) {
190193
};
191194
cstate->ScheduleCopy(
192195
{/*copy_fn=*/[buffer = std::make_unique<BufferReturner>(
193-
[buffer_to_return, cstate]() {
194-
cstate->ReturnBuffer(buffer_to_return);
196+
[b = buffers_to_return[0], cstate]() {
197+
cstate->ReturnBuffer(b);
195198
})](void* dst, int64_t offset,
196199
int64_t transfer_size) -> xla::Future<> {
197200
return xla::Future<>(absl::OkStatus());
198201
},
199202
/*buffer_id=*/0,
200203
/*offset=*/100,
201204
/*size=*/100},
202-
[](PremappedCopierState* state, absl::StatusOr<void*> buf,
203-
const DmaCopyChunk& chunk) {
205+
[buffer = std::make_unique<BufferReturner>(
206+
[b = buffers_to_return[1], cstate]() { cstate->ReturnBuffer(b); })](
207+
PremappedCopierState* state, absl::StatusOr<void*> buf,
208+
const DmaCopyChunk& chunk) {
204209
TF_CHECK_OK(buf.status());
205210
state->ReturnBuffer(buf.value());
206211
});

0 commit comments

Comments
 (0)