@@ -165,20 +165,23 @@ TEST(PremappedCopierState, FreeCycle) {
165165 TF_ASSERT_OK_AND_ASSIGN (
166166 auto scratch, AllocateAndMapPjrtMemory (pjrt_client, 1024 * 1024 * 16 ));
167167 auto cstate = std::make_shared<PremappedCopierState>(scratch, 4 , 4096 );
168- void * buffer_to_return = nullptr ;
169- cstate->ScheduleCopy ({/* copy_fn=*/ [](void * dst, int64_t offset,
170- int64_t transfer_size) -> xla::Future<> {
171- return xla::Future<>(absl::OkStatus ());
172- },
173- /* buffer_id=*/ 0 ,
174- /* offset=*/ 100 ,
175- /* size=*/ 100 },
176- [&buffer_to_return](PremappedCopierState* state,
177- absl::StatusOr<void *> buf,
178- const DmaCopyChunk& chunk) {
179- TF_CHECK_OK (buf.status ());
180- buffer_to_return = buf.value ();
181- });
168+ std::vector<void *> buffers_to_return;
169+ for (size_t i = 0 ; i < 2 ; ++i) {
170+ cstate->ScheduleCopy (
171+ {/* copy_fn=*/ [](void * dst, int64_t offset,
172+ int64_t transfer_size) -> xla::Future<> {
173+ return xla::Future<>(absl::OkStatus ());
174+ },
175+ /* buffer_id=*/ 0 ,
176+ /* offset=*/ 100 ,
177+ /* size=*/ 100 },
178+ [&buffers_to_return](PremappedCopierState* state,
179+ absl::StatusOr<void *> buf,
180+ const DmaCopyChunk& chunk) {
181+ TF_CHECK_OK (buf.status ());
182+ buffers_to_return.push_back (buf.value ());
183+ });
184+ }
182185 class BufferReturner {
183186 public:
184187 explicit BufferReturner (absl::AnyInvocable<void () &&> on_done)
@@ -190,17 +193,19 @@ TEST(PremappedCopierState, FreeCycle) {
190193 };
191194 cstate->ScheduleCopy (
192195 {/* copy_fn=*/ [buffer = std::make_unique<BufferReturner>(
193- [buffer_to_return , cstate]() {
194- cstate->ReturnBuffer (buffer_to_return );
196+ [b = buffers_to_return[ 0 ] , cstate]() {
197+ cstate->ReturnBuffer (b );
195198 })](void * dst, int64_t offset,
196199 int64_t transfer_size) -> xla::Future<> {
197200 return xla::Future<>(absl::OkStatus ());
198201 },
199202 /* buffer_id=*/ 0 ,
200203 /* offset=*/ 100 ,
201204 /* size=*/ 100 },
202- [](PremappedCopierState* state, absl::StatusOr<void *> buf,
203- const DmaCopyChunk& chunk) {
205+ [buffer = std::make_unique<BufferReturner>(
206+ [b = buffers_to_return[1 ], cstate]() { cstate->ReturnBuffer (b); })](
207+ PremappedCopierState* state, absl::StatusOr<void *> buf,
208+ const DmaCopyChunk& chunk) {
204209 TF_CHECK_OK (buf.status ());
205210 state->ReturnBuffer (buf.value ());
206211 });
0 commit comments