5050namespace tensorstore {
5151namespace internal_kvstore_batch {
5252
53- // Common portion of read request used with `BatchReadEntry`.
53+ // Type trait that checks if a type is a valid `ReadRequest` type as required
54+ // by `BatchReadEntry`. Valid types include any structs with the following
55+ // members:
5456//
55- // This is combined in `ReadRequest` with additional fields only used for some
56- // kvstore implementations.
57+ // Promise<kvstore::ReadResult> promise;
58+ // OptionalByteRangeRequest byte_range;
59+ //
60+ // Additional members may be present, such as
61+ // `kvstore::ReadGenerationConditions` or `kvstore::Key`, but at least those two
62+ // members must be present.
63+ //
64+ // See `ByteRangeReadRequest` and `ByteRangeGenerationReadRequest` below for
65+ // specific types that satisfy this requirement.
66+ template <typename T, typename = void >
67+ struct IsByteRangeReadRequestLike : std::false_type {};
68+
69+ template <typename T>
70+ struct IsByteRangeReadRequestLike <
71+ T, std::void_t <decltype (std::declval<T&>().promise),
72+ decltype (std::declval<T&>().byte_range)>>
73+ : public std::conjunction<
74+ std::is_same<decltype(std::declval<T&>().promise),
75+ Promise<kvstore::ReadResult>>,
76+ std::is_same<decltype(std::declval<T&>().byte_range),
77+ OptionalByteRangeRequest>> {};
78+
79+ template <typename T>
80+ constexpr bool IsByteRangeReadRequestLikeV =
81+ IsByteRangeReadRequestLike<T>::value;
82+
83+ // Common `BatchReadEntry` ReadRequest type.
5784struct ByteRangeReadRequest {
5885 Promise<kvstore::ReadResult> promise;
5986 OptionalByteRangeRequest byte_range;
6087};
6188
62- // Individual read request (entry in batch) used with `BatchReadEntry`.
63- //
64- // Possibly `Member` types include:
89+ // Common `BatchReadEntry` ReadRequest type request with generation conditions.
6590//
66- // - `kvstore::ReadGenerationConditions`
67- // - `kvstore::Key`
68- // - some derived key type, like `uint64_t`.
69- template <typename ... Member>
70- using ReadRequest = std::tuple<ByteRangeReadRequest, Member...>;
91+ // This is used when generation conditions are not included in the batch entry
92+ // key.
93+ struct ByteRangeGenerationReadRequest {
94+ Promise<kvstore::ReadResult> promise;
95+ OptionalByteRangeRequest byte_range;
96+ kvstore::ReadGenerationConditions generation_conditions;
97+ };
7198
7299// Batch of read requests with an aggregate staleness bound, used by
73100// `BatchReadEntry`.
74101//
75102// The aggregate staleness bound is set to the maximum staleness bound of all
76103// individual requests.
104+ //
105+ // \tparam RequestType Must satisfy `IsByteRangeReadRequestLikeV`.
77106template <typename RequestType>
78107struct RequestBatch {
108+ static_assert (IsByteRangeReadRequestLikeV<RequestType>);
109+
79110 using Request = RequestType;
111+
80112 absl::Time staleness_bound = absl::InfinitePast();
81113 absl::InlinedVector<Request, 1 > requests;
82114
@@ -123,10 +155,13 @@ template <typename DerivedDriver, typename RequestType,
123155 typename ... BatchEntryKeyMember>
124156class BatchReadEntry : public Batch ::Impl::Entry {
125157 public:
158+ static_assert (IsByteRangeReadRequestLikeV<RequestType>);
159+
126160 using Driver = DerivedDriver;
127161 using BatchEntryKey =
128162 std::tuple<internal::IntrusivePtr<DerivedDriver>, BatchEntryKeyMember...>;
129163 using Request = RequestType;
164+
130165 using KeyParam =
131166 std::tuple<DerivedDriver*, KeyParamType<BatchEntryKeyMember>...>;
132167
@@ -204,7 +239,7 @@ class BatchReadEntry : public Batch::Impl::Entry {
204239 absl::Mutex mutex_;
205240
206241 void AddRequest (absl::Time staleness_bound, Request&& request) {
207- absl::MutexLock lock (& mutex_);
242+ absl::MutexLock lock (mutex_);
208243 request_batch.AddRequest (staleness_bound, std::move (request));
209244 }
210245};
@@ -217,10 +252,13 @@ void SetCommonResult(span<const Request> requests,
217252 Result<kvstore::ReadResult>&& result) {
218253 if (requests.empty ()) return ;
219254 for (size_t i = 1 ; i < requests.size (); ++i) {
220- std::get<ByteRangeReadRequest>(requests[i]).promise .SetResult (result);
255+ if (requests[i].promise .result_needed ()) {
256+ requests[i].promise .SetResult (result);
257+ }
258+ }
259+ if (requests[0 ].promise .result_needed ()) {
260+ requests[0 ].promise .SetResult (std::move (result));
221261 }
222- std::get<ByteRangeReadRequest>(requests[0 ])
223- .promise .SetResult (std::move (result));
224262}
225263template <typename Requests>
226264void SetCommonResult (const Requests& requests,
@@ -231,11 +269,10 @@ void SetCommonResult(const Requests& requests,
231269
232270template <typename Request>
233271void SortRequestsByStartByte (span<Request> requests) {
234- std::sort (
235- requests.begin (), requests.end (), [](const Request& a, const Request& b) {
236- return std::get<ByteRangeReadRequest>(a).byte_range .inclusive_min <
237- std::get<ByteRangeReadRequest>(b).byte_range .inclusive_min ;
238- });
272+ std::sort (requests.begin (), requests.end (),
273+ [](const Request& a, const Request& b) {
274+ return a.byte_range .inclusive_min < b.byte_range .inclusive_min ;
275+ });
239276}
240277
241278// Resolves coalesced requests with the appropriate cord subranges.
@@ -244,19 +281,18 @@ void ResolveCoalescedRequests(ByteRange coalesced_byte_range,
244281 span<Request> coalesced_requests,
245282 kvstore::ReadResult&& read_result) {
246283 for (auto & request : coalesced_requests) {
247- auto & byte_range_request = std::get<ByteRangeReadRequest>(request);
248284 kvstore::ReadResult sub_read_result;
249285 sub_read_result.stamp = read_result.stamp ;
250286 sub_read_result.state = read_result.state ;
251287 if (read_result.state == kvstore::ReadResult::kValue ) {
252288 ABSL_DCHECK_EQ (coalesced_byte_range.size (), read_result.value .size ());
253- int64_t request_start = byte_range_request. byte_range . inclusive_min -
254- coalesced_byte_range.inclusive_min ;
255- int64_t request_size = byte_range_request .byte_range .size ();
289+ int64_t request_start =
290+ request. byte_range . inclusive_min - coalesced_byte_range.inclusive_min ;
291+ int64_t request_size = request .byte_range .size ();
256292 sub_read_result.value =
257293 read_result.value .Subcord (request_start, request_size);
258294 }
259- byte_range_request .promise .SetResult (std::move (sub_read_result));
295+ request .promise .SetResult (std::move (sub_read_result));
260296 }
261297}
262298
@@ -277,19 +313,17 @@ void ResolveCoalescedRequests(ByteRange coalesced_byte_range,
277313template <typename Request, typename Predicate, typename Callback>
278314void ForEachCoalescedRequest (span<Request> requests, Predicate predicate,
279315 Callback callback) {
316+ static_assert (IsByteRangeReadRequestLikeV<Request>);
317+
280318 SortRequestsByStartByte (requests);
281319
282320 size_t request_i = 0 ;
283321 while (request_i < requests.size ()) {
284- auto coalesced_byte_range =
285- std::get<ByteRangeReadRequest>(requests[request_i])
286- .byte_range .AsByteRange ();
322+ auto coalesced_byte_range = requests[request_i].byte_range .AsByteRange ();
287323 size_t end_request_i;
288324 for (end_request_i = request_i + 1 ; end_request_i < requests.size ();
289325 ++end_request_i) {
290- auto next_byte_range =
291- std::get<ByteRangeReadRequest>(requests[end_request_i])
292- .byte_range .AsByteRange ();
326+ auto next_byte_range = requests[end_request_i].byte_range .AsByteRange ();
293327 if (next_byte_range.inclusive_min < coalesced_byte_range.exclusive_max ||
294328 predicate (coalesced_byte_range, next_byte_range.inclusive_min )) {
295329 coalesced_byte_range.exclusive_max = std::max (
@@ -314,12 +348,9 @@ void ForEachCoalescedRequest(span<Request> requests, Predicate predicate,
314348template <typename Request>
315349bool ValidateRequestGeneration (Request& request,
316350 const TimestampedStorageGeneration& stamp) {
317- auto & byte_range_request = std::get<ByteRangeReadRequest>(request);
318- if (!byte_range_request.promise .result_needed ()) return false ;
319- if (!std::get<kvstore::ReadGenerationConditions>(request).Matches (
320- stamp.generation )) {
321- byte_range_request.promise .SetResult (
322- kvstore::ReadResult::Unspecified (stamp));
351+ if (!request.promise .result_needed ()) return false ;
352+ if (!request.generation_conditions .Matches (stamp.generation )) {
353+ request.promise .SetResult (kvstore::ReadResult::Unspecified (stamp));
323354 return false ;
324355 }
325356 return true ;
@@ -334,14 +365,15 @@ bool ValidateRequestGeneration(Request& request,
334365template <typename Request>
335366bool ValidateRequestGenerationAndByteRange (
336367 Request& request, const TimestampedStorageGeneration& stamp, int64_t size) {
368+ static_assert (IsByteRangeReadRequestLikeV<Request>);
369+ static_assert (
370+ std::is_member_pointer<decltype (&Request::generation_conditions)>::value);
337371 if (!ValidateRequestGeneration (request, stamp)) {
338372 return false ;
339373 }
340- auto & byte_range_request = std::get<ByteRangeReadRequest>(request);
341374 TENSORSTORE_ASSIGN_OR_RETURN (
342- byte_range_request.byte_range ,
343- byte_range_request.byte_range .Validate (size),
344- (byte_range_request.promise .SetResult (std::move (_)), false ));
375+ request.byte_range , request.byte_range .Validate (size),
376+ (request.promise .SetResult (std::move (_)), false ));
345377 return true ;
346378}
347379
0 commit comments