|
31 | 31 | #include "tensorstore/array.h" |
32 | 32 | #include "tensorstore/box.h" |
33 | 33 | #include "tensorstore/driver/chunk.h" |
34 | | -#include "tensorstore/driver/chunk_receiver_utils.h" |
35 | 34 | #include "tensorstore/index.h" |
36 | 35 | #include "tensorstore/index_space/index_transform.h" |
37 | 36 | #include "tensorstore/index_space/transformed_array.h" |
|
46 | 45 | #include "tensorstore/internal/memory.h" |
47 | 46 | #include "tensorstore/internal/metrics/counter.h" |
48 | 47 | #include "tensorstore/internal/metrics/metadata.h" |
49 | | -#include "tensorstore/internal/mutex.h" |
50 | 48 | #include "tensorstore/internal/nditerable.h" |
51 | 49 | #include "tensorstore/internal/regular_grid.h" |
52 | 50 | #include "tensorstore/kvstore/generation.h" |
53 | 51 | #include "tensorstore/rank.h" |
54 | 52 | #include "tensorstore/read_write_options.h" |
55 | 53 | #include "tensorstore/transaction.h" |
56 | 54 | #include "tensorstore/util/execution/execution.h" |
| 55 | +#include "tensorstore/util/execution/flow_sender_operation_state.h" |
57 | 56 | #include "tensorstore/util/executor.h" |
58 | 57 | #include "tensorstore/util/future.h" |
59 | 58 | #include "tensorstore/util/result.h" |
@@ -413,74 +412,110 @@ struct WriteChunkImpl { |
413 | 412 | } |
414 | 413 | }; |
415 | 414 |
|
416 | | -} // namespace |
417 | | - |
418 | | -void ChunkCache::Read(ReadRequest request, ReadChunkReceiver receiver) { |
419 | | - assert(request.component_index >= 0 && |
420 | | - request.component_index < grid().components.size()); |
421 | | - const auto& component_spec = grid().components[request.component_index]; |
422 | | - // Shared state used while `Read` is in progress. |
423 | | - using ReadOperationState = ChunkOperationState<ReadChunk>; |
424 | | - |
425 | | - assert(component_spec.chunked_to_cell_dimensions.size() == |
426 | | - grid().chunk_shape.size()); |
427 | | - auto state = MakeIntrusivePtr<ReadOperationState>(std::move(receiver)); |
428 | | - internal_grid_partition::RegularGridRef regular_grid{grid().chunk_shape}; |
| 415 | +// Shared state used while `Read` is in progress. |
| 416 | +// |
| 417 | +// Note: The `ReadOperationState::request_` member may hold a reference to |
| 418 | +// a batch; while that reference is held, reads issued as part of the batch |
| 419 | +// cannot complete. |
| 420 | +class ReadOperationState : public AtomicReferenceCount<ReadOperationState> { |
| 421 | + public: |
| 422 | + using ReadCompletionState = |
| 423 | + internal::FlowSenderOperationState<ReadChunk, IndexTransform<>>; |
| 424 | + using BaseReceiver = ReadCompletionState::BaseReceiver; |
| 425 | + |
| 426 | + explicit ReadOperationState(BaseReceiver&& receiver, ChunkCache& self, |
| 427 | + ChunkCache::ReadRequest&& request) |
| 428 | + : completion_(MakeIntrusivePtr<ReadCompletionState>(std::move(receiver))), |
| 429 | + self_(self), |
| 430 | + request_(std::move(request)), |
| 431 | + regular_grid_(self_.grid().chunk_shape), |
| 432 | + iterator_(self_.grid() |
| 433 | + .components[request_.component_index] |
| 434 | + .chunked_to_cell_dimensions, |
| 435 | + regular_grid_, request_.transform) {} |
| 436 | + |
| 437 | + absl::Status InitiateRead() { |
| 438 | + num_reads.Increment(); |
| 439 | + TENSORSTORE_ASSIGN_OR_RETURN( |
| 440 | + auto cell_to_source, |
| 441 | + ComposeTransforms(request_.transform, iterator_.cell_transform())); |
| 442 | + auto entry = |
| 443 | + GetEntryForGridCell(self_, iterator_.output_grid_cell_indices()); |
| 444 | + // Arrange to call `set_value` on the receiver with a `ReadChunk` |
| 445 | + // corresponding to this grid cell once the read request completes |
| 446 | + // successfully. |
| 447 | + ReadChunk chunk; |
| 448 | + chunk.transform = std::move(cell_to_source); |
| 449 | + Future<const void> read_future; |
| 450 | + const auto get_cache_read_request = [&] { |
| 451 | + AsyncCache::AsyncCacheReadRequest cache_request; |
| 452 | + cache_request.staleness_bound = request_.staleness_bound; |
| 453 | + cache_request.batch = request_.batch; |
| 454 | + return cache_request; |
| 455 | + }; |
| 456 | + if (request_.transaction) { |
| 457 | + TENSORSTORE_ASSIGN_OR_RETURN( |
| 458 | + auto node, GetTransactionNode(*entry, request_.transaction)); |
| 459 | + read_future = node->IsUnconditional() |
| 460 | + ? MakeReadyFuture() |
| 461 | + : node->Read(get_cache_read_request()); |
| 462 | + chunk.impl = |
| 463 | + ReadChunkTransactionImpl{request_.component_index, std::move(node), |
| 464 | + request_.fill_missing_data_reads}; |
| 465 | + } else { |
| 466 | + read_future = entry->Read(get_cache_read_request()); |
| 467 | + chunk.impl = ReadChunkImpl{request_.component_index, std::move(entry), |
| 468 | + request_.fill_missing_data_reads}; |
| 469 | + } |
| 470 | + LinkValue( |
| 471 | + [completion = completion_, chunk = std::move(chunk), |
| 472 | + cell_transform = IndexTransform<>(iterator_.cell_transform())]( |
| 473 | + Promise<void> promise, ReadyFuture<const void> future) mutable { |
| 474 | + completion->YieldValue(std::move(chunk), std::move(cell_transform)); |
| 475 | + }, |
| 476 | + completion_->promise, std::move(read_future)); |
| 477 | + return absl::OkStatus(); |
| 478 | + } |
429 | 479 |
|
430 | | - auto status = [&]() -> absl::Status { |
431 | | - internal_grid_partition::PartitionIndexTransformIterator iterator( |
432 | | - component_spec.chunked_to_cell_dimensions, regular_grid, |
433 | | - request.transform); |
434 | | - TENSORSTORE_RETURN_IF_ERROR(iterator.Init()); |
| 480 | + absl::Status IteratorLoop() { |
| 481 | + TENSORSTORE_RETURN_IF_ERROR(iterator_.Init()); |
435 | 482 |
|
436 | | - while (!iterator.AtEnd()) { |
437 | | - if (state->cancelled()) { |
| 483 | + while (!iterator_.AtEnd()) { |
| 484 | + if (cancelled()) { |
438 | 485 | return absl::CancelledError(""); |
439 | 486 | } |
440 | | - num_reads.Increment(); |
441 | | - TENSORSTORE_ASSIGN_OR_RETURN( |
442 | | - auto cell_to_source, |
443 | | - ComposeTransforms(request.transform, iterator.cell_transform())); |
444 | | - auto entry = |
445 | | - GetEntryForGridCell(*this, iterator.output_grid_cell_indices()); |
446 | | - // Arrange to call `set_value` on the receiver with a `ReadChunk` |
447 | | - // corresponding to this grid cell once the read request completes |
448 | | - // successfully. |
449 | | - ReadChunk chunk; |
450 | | - chunk.transform = std::move(cell_to_source); |
451 | | - Future<const void> read_future; |
452 | | - const auto get_cache_read_request = [&] { |
453 | | - AsyncCache::AsyncCacheReadRequest cache_request; |
454 | | - cache_request.staleness_bound = request.staleness_bound; |
455 | | - cache_request.batch = request.batch; |
456 | | - return cache_request; |
457 | | - }; |
458 | | - if (request.transaction) { |
459 | | - TENSORSTORE_ASSIGN_OR_RETURN( |
460 | | - auto node, GetTransactionNode(*entry, request.transaction)); |
461 | | - read_future = node->IsUnconditional() |
462 | | - ? MakeReadyFuture() |
463 | | - : node->Read(get_cache_read_request()); |
464 | | - chunk.impl = |
465 | | - ReadChunkTransactionImpl{request.component_index, std::move(node), |
466 | | - request.fill_missing_data_reads}; |
467 | | - } else { |
468 | | - read_future = entry->Read(get_cache_read_request()); |
469 | | - chunk.impl = ReadChunkImpl{request.component_index, std::move(entry), |
470 | | - request.fill_missing_data_reads}; |
471 | | - } |
472 | | - LinkValue( |
473 | | - [state, chunk = std::move(chunk), |
474 | | - cell_transform = IndexTransform<>(iterator.cell_transform())]( |
475 | | - Promise<void> promise, ReadyFuture<const void> future) mutable { |
476 | | - execution::set_value(state->shared_receiver->receiver, |
477 | | - std::move(chunk), std::move(cell_transform)); |
478 | | - }, |
479 | | - state->promise, std::move(read_future)); |
480 | | - iterator.Advance(); |
| 487 | + TENSORSTORE_RETURN_IF_ERROR(InitiateRead()); |
| 488 | + iterator_.Advance(); |
481 | 489 | } |
482 | 490 | return absl::OkStatus(); |
483 | | - }(); |
| 491 | + } |
| 492 | + |
| 493 | + bool cancelled() const { return completion_->cancelled(); } |
| 494 | + |
| 495 | + void SetError(absl::Status status) { |
| 496 | + completion_->SetError(std::move(status)); |
| 497 | + } |
| 498 | + |
| 499 | + private: |
| 500 | + IntrusivePtr<ReadCompletionState> completion_; |
| 501 | + ChunkCache& self_; |
| 502 | + ChunkCache::ReadRequest request_; |
| 503 | + internal_grid_partition::RegularGridRef regular_grid_; |
| 504 | + internal_grid_partition::PartitionIndexTransformIterator iterator_; |
| 505 | +}; |
| 506 | + |
| 507 | +} // namespace |
| 508 | + |
| 509 | +void ChunkCache::Read(ReadRequest request, ReadChunkReceiver receiver) { |
| 510 | + [[maybe_unused]] const auto& grid = this->grid(); |
| 511 | + assert(request.component_index >= 0 && |
| 512 | + request.component_index < grid.components.size()); |
| 513 | + assert(grid.components[request.component_index] |
| 514 | + .chunked_to_cell_dimensions.size() == grid.chunk_shape.size()); |
| 515 | + |
| 516 | + auto state = MakeIntrusivePtr<ReadOperationState>(std::move(receiver), *this, |
| 517 | + std::move(request)); |
| 518 | + auto status = state->IteratorLoop(); |
484 | 519 | if (!status.ok()) { |
485 | 520 | state->SetError(std::move(status)); |
486 | 521 | } |
|
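For context on the receiver side of this change: `FlowSenderOperationState` adapts a flow receiver, and `ChunkCache::Read` emits one `(ReadChunk, cell_transform)` pair per grid cell intersecting the request. The sketch below shows a hand-rolled receiver satisfying that generic `set_starting`/`set_value`/`set_done`/`set_error`/`set_stopping` protocol; the type `ExampleChunkReceiver` and the usage comment are illustrative assumptions, not part of this commit.

```cpp
#include "absl/status/status.h"
#include "tensorstore/driver/chunk.h"
#include "tensorstore/index_space/index_transform.h"

// Hypothetical receiver (not part of this change). `set_value` is invoked
// once per grid cell that intersects the read request, with the `ReadChunk`
// for that cell and the transform from the cell domain to the request domain.
struct ExampleChunkReceiver {
  template <typename CancelReceiver>
  void set_starting(CancelReceiver cancel) {
    // `cancel` may be invoked to request early termination; `IteratorLoop`
    // above checks `cancelled()` before initiating each per-cell read.
  }
  void set_value(tensorstore::internal::ReadChunk chunk,
                 tensorstore::IndexTransform<> cell_transform) {
    // Consume `chunk`, restricted by `cell_transform`.
  }
  void set_done() {}                      // All chunks were delivered.
  void set_error(absl::Status status) {}  // Terminal failure for the request.
  void set_stopping() {}                  // Cancellation is no longer possible.
};

// Assumed call shape (sketch only):
//   cache->Read(std::move(request), ExampleChunkReceiver{});
```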