Skip to content

Commit 6cd8967

Browse files
laramielcopybara-github
authored andcommitted
decompose/encapsulate PartitionIndexTransformIterator in ReadOperationState for ChunkCache
PiperOrigin-RevId: 805154398 Change-Id: I2872780c42d8fdefeee7f834277a7f44dc66ab6d
1 parent 7b96ffc commit 6cd8967

File tree

18 files changed

+158
-112
lines changed

18 files changed

+158
-112
lines changed

tensorstore/driver/BUILD

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -400,10 +400,10 @@ tensorstore_cc_library(
400400
name = "chunk_receiver_utils",
401401
hdrs = ["chunk_receiver_utils.h"],
402402
deps = [
403-
":chunk",
404403
"//tensorstore/index_space:index_transform",
405404
"//tensorstore/internal:intrusive_ptr",
406405
"//tensorstore/util:future",
406+
"//tensorstore/util/execution",
407407
"//tensorstore/util/execution:any_receiver",
408408
"//tensorstore/util/execution:flow_sender_operation_state",
409409
"@abseil-cpp//absl/status",

tensorstore/driver/chunk_receiver_utils.h

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,37 +15,36 @@
1515
#ifndef TENSORSTORE_INTERNAL_CHUNK_RECEIVER_UTILS_H_
1616
#define TENSORSTORE_INTERNAL_CHUNK_RECEIVER_UTILS_H_
1717

18+
#include <utility>
19+
1820
#include "absl/status/status.h"
19-
#include "tensorstore/driver/chunk.h"
2021
#include "tensorstore/index_space/index_transform.h"
2122
#include "tensorstore/internal/intrusive_ptr.h"
2223
#include "tensorstore/util/execution/any_receiver.h"
24+
#include "tensorstore/util/execution/execution.h"
2325
#include "tensorstore/util/execution/flow_sender_operation_state.h"
2426
#include "tensorstore/util/future.h"
2527

2628
namespace tensorstore {
2729
namespace internal {
2830

29-
template <typename ChunkT>
30-
struct ChunkOperationState
31-
: public FlowSenderOperationState<ChunkT, IndexTransform<>> {
32-
using ChunkType = ChunkT;
33-
using Base = FlowSenderOperationState<ChunkT, IndexTransform<>>;
34-
35-
using Base::Base;
36-
};
37-
3831
// Forwarding receiver which satisfies `ReadChunkReceiver` or
3932
// `WriteChunkReceiver`. The starting/stopping/error/done parts of the protocol
4033
// are handled by the future, so this only forwards set_error and set_value
4134
// calls.
42-
template <typename StateType>
35+
template <typename ChunkType, typename StateType>
4336
struct ForwardingChunkOperationReceiver {
44-
using ChunkType = typename StateType::ChunkType;
4537
IntrusivePtr<StateType> state;
4638
IndexTransform<> cell_transform;
4739
FutureCallbackRegistration cancel_registration;
4840

41+
// StateType must be a FlowSenderOperationState or a subclass of that.
42+
static_assert(
43+
std::is_same_v<FlowSenderOperationState<ChunkType, IndexTransform<>>,
44+
StateType> ||
45+
std::is_base_of_v<FlowSenderOperationState<ChunkType, IndexTransform<>>,
46+
StateType>);
47+
4948
void set_starting(AnyCancelReceiver cancel) {
5049
cancel_registration =
5150
state->promise.ExecuteWhenNotNeeded(std::move(cancel));

tensorstore/driver/neuroglancer_precomputed/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,7 @@ tensorstore_cc_library(
259259
"//tensorstore/util:str_cat",
260260
"//tensorstore/util:unit",
261261
"//tensorstore/util/execution:any_receiver",
262+
"//tensorstore/util/execution:flow_sender_operation_state",
262263
"//tensorstore/util/garbage_collection",
263264
"@abseil-cpp//absl/algorithm:container",
264265
"@abseil-cpp//absl/container:inlined_vector",

tensorstore/driver/neuroglancer_precomputed/driver.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@
8686
#include "tensorstore/util/dimension_set.h"
8787
#include "tensorstore/util/division.h"
8888
#include "tensorstore/util/execution/any_receiver.h"
89+
#include "tensorstore/util/execution/flow_sender_operation_state.h"
8990
#include "tensorstore/util/future.h"
9091
#include "tensorstore/util/garbage_collection/fwd.h"
9192
#include "tensorstore/util/result.h"
@@ -760,9 +761,10 @@ class RegularlyShardedDataCache : public ShardedDataCache {
760761
shard_shape_in_elements[dim] =
761762
scale.chunk_sizes[0][dim] * hierarchy_.shard_shape_in_chunks[dim];
762763
}
763-
using State = internal::ChunkOperationState<ChunkType>;
764+
using State =
765+
internal::FlowSenderOperationState<ChunkType, IndexTransform<>>;
764766
using ForwardingReceiver =
765-
internal::ForwardingChunkOperationReceiver<State>;
767+
internal::ForwardingChunkOperationReceiver<ChunkType, State>;
766768
auto state = internal::MakeIntrusivePtr<State>(std::move(receiver));
767769

768770
auto status = [&]() -> absl::Status {

tensorstore/driver/stack/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ tensorstore_cc_library(
6969
"//tensorstore/util:status",
7070
"//tensorstore/util:str_cat",
7171
"//tensorstore/util/execution:any_receiver",
72+
"//tensorstore/util/execution:flow_sender_operation_state",
7273
"//tensorstore/util/garbage_collection",
7374
"@abseil-cpp//absl/container:flat_hash_map",
7475
"@abseil-cpp//absl/container:flat_hash_set",

tensorstore/driver/stack/driver.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
#include "tensorstore/transaction.h"
6767
#include "tensorstore/util/dimension_set.h"
6868
#include "tensorstore/util/execution/any_receiver.h"
69+
#include "tensorstore/util/execution/flow_sender_operation_state.h"
6970
#include "tensorstore/util/executor.h"
7071
#include "tensorstore/util/future.h"
7172
#include "tensorstore/util/iterate_over_index_range.h"
@@ -789,16 +790,18 @@ struct OpenLayerOp {
789790
// Asynchronous state for StackDriver::{Read,Write} that maintains reference
790791
// counts while the read/write operation is in progress.
791792
template <typename ChunkType>
792-
struct ReadOrWriteState : public internal::ChunkOperationState<ChunkType> {
793+
struct ReadOrWriteState
794+
: public internal::FlowSenderOperationState<ChunkType, IndexTransform<>> {
793795
static constexpr ReadWriteMode kMode = std::is_same_v<ChunkType, ReadChunk>
794796
? ReadWriteMode::read
795797
: ReadWriteMode::write;
796798
using RequestType = std::conditional_t<std::is_same_v<ChunkType, ReadChunk>,
797799
internal::Driver::ReadRequest,
798800
internal::Driver::WriteRequest>;
799-
using Base = internal::ChunkOperationState<ChunkType>;
801+
using Base = internal::FlowSenderOperationState<ChunkType, IndexTransform<>>;
800802
using State = ReadOrWriteState<ChunkType>;
801-
using ForwardingReceiver = internal::ForwardingChunkOperationReceiver<State>;
803+
using ForwardingReceiver =
804+
internal::ForwardingChunkOperationReceiver<ChunkType, State>;
802805

803806
using Base::Base;
804807

tensorstore/driver/zarr3/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ tensorstore_cc_library(
236236
"//tensorstore/util:span",
237237
"//tensorstore/util:status",
238238
"//tensorstore/util/execution:any_receiver",
239+
"//tensorstore/util/execution:flow_sender_operation_state",
239240
"@abseil-cpp//absl/container:inlined_vector",
240241
"@abseil-cpp//absl/status",
241242
"@abseil-cpp//absl/strings:cord",

tensorstore/driver/zarr3/chunk_cache.cc

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
#include "tensorstore/transaction.h"
6161
#include "tensorstore/util/division.h"
6262
#include "tensorstore/util/execution/any_receiver.h"
63+
#include "tensorstore/util/execution/flow_sender_operation_state.h"
6364
#include "tensorstore/util/future.h"
6465
#include "tensorstore/util/result.h"
6566
#include "tensorstore/util/span.h"
@@ -261,8 +262,9 @@ void ShardedReadOrWrite(
261262
const auto& grid = self.grid();
262263
const auto& component_spec = grid.components[0];
263264

264-
using State = internal::ChunkOperationState<ChunkType>;
265-
using ForwardingReceiver = internal::ForwardingChunkOperationReceiver<State>;
265+
using State = internal::FlowSenderOperationState<ChunkType, IndexTransform<>>;
266+
using ForwardingReceiver =
267+
internal::ForwardingChunkOperationReceiver<ChunkType, State>;
266268
span<const Index> chunk_shape = grid.chunk_shape;
267269
span<const DimensionIndex> chunked_to_cell_dimensions =
268270
component_spec.chunked_to_cell_dimensions;

tensorstore/internal/cache/BUILD

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -338,7 +338,6 @@ tensorstore_cc_library(
338338
"//tensorstore:read_write_options",
339339
"//tensorstore:transaction",
340340
"//tensorstore/driver:chunk",
341-
"//tensorstore/driver:chunk_receiver_utils",
342341
"//tensorstore/driver:read_request",
343342
"//tensorstore/driver:write_request",
344343
"//tensorstore/index_space:index_transform",
@@ -350,7 +349,6 @@ tensorstore_cc_library(
350349
"//tensorstore/internal:intrusive_ptr",
351350
"//tensorstore/internal:lock_collection",
352351
"//tensorstore/internal:memory",
353-
"//tensorstore/internal:mutex",
354352
"//tensorstore/internal:nditerable",
355353
"//tensorstore/internal:regular_grid",
356354
"//tensorstore/internal/metrics",
@@ -363,6 +361,7 @@ tensorstore_cc_library(
363361
"//tensorstore/util:status",
364362
"//tensorstore/util:str_cat",
365363
"//tensorstore/util/execution",
364+
"//tensorstore/util/execution:flow_sender_operation_state",
366365
"@abseil-cpp//absl/base:core_headers",
367366
"@abseil-cpp//absl/container:inlined_vector",
368367
"@abseil-cpp//absl/log:absl_log",

tensorstore/internal/cache/chunk_cache.cc

Lines changed: 99 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
#include "tensorstore/array.h"
3232
#include "tensorstore/box.h"
3333
#include "tensorstore/driver/chunk.h"
34-
#include "tensorstore/driver/chunk_receiver_utils.h"
3534
#include "tensorstore/index.h"
3635
#include "tensorstore/index_space/index_transform.h"
3736
#include "tensorstore/index_space/transformed_array.h"
@@ -46,14 +45,14 @@
4645
#include "tensorstore/internal/memory.h"
4746
#include "tensorstore/internal/metrics/counter.h"
4847
#include "tensorstore/internal/metrics/metadata.h"
49-
#include "tensorstore/internal/mutex.h"
5048
#include "tensorstore/internal/nditerable.h"
5149
#include "tensorstore/internal/regular_grid.h"
5250
#include "tensorstore/kvstore/generation.h"
5351
#include "tensorstore/rank.h"
5452
#include "tensorstore/read_write_options.h"
5553
#include "tensorstore/transaction.h"
5654
#include "tensorstore/util/execution/execution.h"
55+
#include "tensorstore/util/execution/flow_sender_operation_state.h"
5756
#include "tensorstore/util/executor.h"
5857
#include "tensorstore/util/future.h"
5958
#include "tensorstore/util/result.h"
@@ -413,74 +412,110 @@ struct WriteChunkImpl {
413412
}
414413
};
415414

416-
} // namespace
417-
418-
void ChunkCache::Read(ReadRequest request, ReadChunkReceiver receiver) {
419-
assert(request.component_index >= 0 &&
420-
request.component_index < grid().components.size());
421-
const auto& component_spec = grid().components[request.component_index];
422-
// Shared state used while `Read` is in progress.
423-
using ReadOperationState = ChunkOperationState<ReadChunk>;
424-
425-
assert(component_spec.chunked_to_cell_dimensions.size() ==
426-
grid().chunk_shape.size());
427-
auto state = MakeIntrusivePtr<ReadOperationState>(std::move(receiver));
428-
internal_grid_partition::RegularGridRef regular_grid{grid().chunk_shape};
415+
// Shared state used while `Read` is in progress.
416+
//
417+
// Note: The `ReadOperationState::request_` member may contain a reference
418+
// to a batch, and holding the batch will keep the read operation from
419+
// completing until the batch is no longer referenced.
420+
class ReadOperationState : public AtomicReferenceCount<ReadOperationState> {
421+
public:
422+
using ReadCompletionState =
423+
internal::FlowSenderOperationState<ReadChunk, IndexTransform<>>;
424+
using BaseReceiver = ReadCompletionState::BaseReceiver;
425+
426+
explicit ReadOperationState(BaseReceiver&& receiver, ChunkCache& self,
427+
ChunkCache::ReadRequest&& request)
428+
: completion_(MakeIntrusivePtr<ReadCompletionState>(std::move(receiver))),
429+
self_(self),
430+
request_(std::move(request)),
431+
regular_grid_(self_.grid().chunk_shape),
432+
iterator_(self_.grid()
433+
.components[request_.component_index]
434+
.chunked_to_cell_dimensions,
435+
regular_grid_, request_.transform) {}
436+
437+
absl::Status InitiateRead() {
438+
num_reads.Increment();
439+
TENSORSTORE_ASSIGN_OR_RETURN(
440+
auto cell_to_source,
441+
ComposeTransforms(request_.transform, iterator_.cell_transform()));
442+
auto entry =
443+
GetEntryForGridCell(self_, iterator_.output_grid_cell_indices());
444+
// Arrange to call `set_value` on the receiver with a `ReadChunk`
445+
// corresponding to this grid cell once the read request completes
446+
// successfully.
447+
ReadChunk chunk;
448+
chunk.transform = std::move(cell_to_source);
449+
Future<const void> read_future;
450+
const auto get_cache_read_request = [&] {
451+
AsyncCache::AsyncCacheReadRequest cache_request;
452+
cache_request.staleness_bound = request_.staleness_bound;
453+
cache_request.batch = request_.batch;
454+
return cache_request;
455+
};
456+
if (request_.transaction) {
457+
TENSORSTORE_ASSIGN_OR_RETURN(
458+
auto node, GetTransactionNode(*entry, request_.transaction));
459+
read_future = node->IsUnconditional()
460+
? MakeReadyFuture()
461+
: node->Read(get_cache_read_request());
462+
chunk.impl =
463+
ReadChunkTransactionImpl{request_.component_index, std::move(node),
464+
request_.fill_missing_data_reads};
465+
} else {
466+
read_future = entry->Read(get_cache_read_request());
467+
chunk.impl = ReadChunkImpl{request_.component_index, std::move(entry),
468+
request_.fill_missing_data_reads};
469+
}
470+
LinkValue(
471+
[completion = completion_, chunk = std::move(chunk),
472+
cell_transform = IndexTransform<>(iterator_.cell_transform())](
473+
Promise<void> promise, ReadyFuture<const void> future) mutable {
474+
completion->YieldValue(std::move(chunk), std::move(cell_transform));
475+
},
476+
completion_->promise, std::move(read_future));
477+
return absl::OkStatus();
478+
}
429479

430-
auto status = [&]() -> absl::Status {
431-
internal_grid_partition::PartitionIndexTransformIterator iterator(
432-
component_spec.chunked_to_cell_dimensions, regular_grid,
433-
request.transform);
434-
TENSORSTORE_RETURN_IF_ERROR(iterator.Init());
480+
absl::Status IteratorLoop() {
481+
TENSORSTORE_RETURN_IF_ERROR(iterator_.Init());
435482

436-
while (!iterator.AtEnd()) {
437-
if (state->cancelled()) {
483+
while (!iterator_.AtEnd()) {
484+
if (cancelled()) {
438485
return absl::CancelledError("");
439486
}
440-
num_reads.Increment();
441-
TENSORSTORE_ASSIGN_OR_RETURN(
442-
auto cell_to_source,
443-
ComposeTransforms(request.transform, iterator.cell_transform()));
444-
auto entry =
445-
GetEntryForGridCell(*this, iterator.output_grid_cell_indices());
446-
// Arrange to call `set_value` on the receiver with a `ReadChunk`
447-
// corresponding to this grid cell once the read request completes
448-
// successfully.
449-
ReadChunk chunk;
450-
chunk.transform = std::move(cell_to_source);
451-
Future<const void> read_future;
452-
const auto get_cache_read_request = [&] {
453-
AsyncCache::AsyncCacheReadRequest cache_request;
454-
cache_request.staleness_bound = request.staleness_bound;
455-
cache_request.batch = request.batch;
456-
return cache_request;
457-
};
458-
if (request.transaction) {
459-
TENSORSTORE_ASSIGN_OR_RETURN(
460-
auto node, GetTransactionNode(*entry, request.transaction));
461-
read_future = node->IsUnconditional()
462-
? MakeReadyFuture()
463-
: node->Read(get_cache_read_request());
464-
chunk.impl =
465-
ReadChunkTransactionImpl{request.component_index, std::move(node),
466-
request.fill_missing_data_reads};
467-
} else {
468-
read_future = entry->Read(get_cache_read_request());
469-
chunk.impl = ReadChunkImpl{request.component_index, std::move(entry),
470-
request.fill_missing_data_reads};
471-
}
472-
LinkValue(
473-
[state, chunk = std::move(chunk),
474-
cell_transform = IndexTransform<>(iterator.cell_transform())](
475-
Promise<void> promise, ReadyFuture<const void> future) mutable {
476-
execution::set_value(state->shared_receiver->receiver,
477-
std::move(chunk), std::move(cell_transform));
478-
},
479-
state->promise, std::move(read_future));
480-
iterator.Advance();
487+
TENSORSTORE_RETURN_IF_ERROR(InitiateRead());
488+
iterator_.Advance();
481489
}
482490
return absl::OkStatus();
483-
}();
491+
}
492+
493+
bool cancelled() const { return completion_->cancelled(); }
494+
495+
void SetError(absl::Status status) {
496+
completion_->SetError(std::move(status));
497+
}
498+
499+
private:
500+
IntrusivePtr<ReadCompletionState> completion_;
501+
ChunkCache& self_;
502+
ChunkCache::ReadRequest request_;
503+
internal_grid_partition::RegularGridRef regular_grid_;
504+
internal_grid_partition::PartitionIndexTransformIterator iterator_;
505+
};
506+
507+
} // namespace
508+
509+
void ChunkCache::Read(ReadRequest request, ReadChunkReceiver receiver) {
510+
[[maybe_unused]] const auto& grid = this->grid();
511+
assert(request.component_index >= 0 &&
512+
request.component_index < grid.components.size());
513+
assert(grid.components[request.component_index]
514+
.chunked_to_cell_dimensions.size() == grid.chunk_shape.size());
515+
516+
auto state = MakeIntrusivePtr<ReadOperationState>(std::move(receiver), *this,
517+
std::move(request));
518+
auto status = state->IteratorLoop();
484519
if (!status.ok()) {
485520
state->SetError(std::move(status));
486521
}

0 commit comments

Comments
 (0)