Skip to content

Commit d115425

Browse files
laramielcopybara-github
authored andcommitted
Convert batch Request type from tuple<> to a struct type with traits detection.
While working on other batch requests I found that I wanted to try a ReadRequest type with two string fields. Access to the fields was by type in a few places, and since ReadRequest was implemented as a tuple<...>, duplicate field types are problematic, so I've removed that aspect of ReadRequest. PiperOrigin-RevId: 832726371 Change-Id: Id6ce5c31edb4f7e16fc79f47bd4a1d2e3534dda7
1 parent 052f875 commit d115425

File tree

7 files changed

+195
-194
lines changed

7 files changed

+195
-194
lines changed

tensorstore/kvstore/batch_util.h

Lines changed: 73 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -50,33 +50,65 @@
5050
namespace tensorstore {
5151
namespace internal_kvstore_batch {
5252

53-
// Common portion of read request used with `BatchReadEntry`.
53+
// Type trait that checks if a type is a valid `ReadRequest` type as required
54+
// by `BatchReadEntry`. Valid types include any structs with the following
55+
// members:
5456
//
55-
// This is combined in `ReadRequest` with additional fields only used for some
56-
// kvstore implementations.
57+
// Promise<kvstore::ReadResult> promise;
58+
// OptionalByteRangeRequest byte_range;
59+
//
60+
// Additional members may be present, such as
61+
// `kvstore::ReadGenerationConditions` or `kvstore::Key`, but at least those two
62+
// members must be present.
63+
//
64+
// See `ByteRangeReadRequest` and `ByteRangeGenerationReadRequest` below for
65+
// specific types that satisfy this requirement.
66+
template <typename T, typename = void>
67+
struct IsByteRangeReadRequestLike : std::false_type {};
68+
69+
template <typename T>
70+
struct IsByteRangeReadRequestLike<
71+
T, std::void_t<decltype(std::declval<T&>().promise),
72+
decltype(std::declval<T&>().byte_range)>>
73+
: public std::conjunction<
74+
std::is_same<decltype(std::declval<T&>().promise),
75+
Promise<kvstore::ReadResult>>,
76+
std::is_same<decltype(std::declval<T&>().byte_range),
77+
OptionalByteRangeRequest>> {};
78+
79+
template <typename T>
80+
constexpr bool IsByteRangeReadRequestLikeV =
81+
IsByteRangeReadRequestLike<T>::value;
82+
83+
// Common `BatchReadEntry` ReadRequest type.
5784
struct ByteRangeReadRequest {
5885
Promise<kvstore::ReadResult> promise;
5986
OptionalByteRangeRequest byte_range;
6087
};
6188

62-
// Individual read request (entry in batch) used with `BatchReadEntry`.
63-
//
64-
// Possibly `Member` types include:
89+
// Common `BatchReadEntry` ReadRequest type request with generation conditions.
6590
//
66-
// - `kvstore::ReadGenerationConditions`
67-
// - `kvstore::Key`
68-
// - some derived key type, like `uint64_t`.
69-
template <typename... Member>
70-
using ReadRequest = std::tuple<ByteRangeReadRequest, Member...>;
91+
// This is used when generation conditions are not included in the batch entry
92+
// key.
93+
struct ByteRangeGenerationReadRequest {
94+
Promise<kvstore::ReadResult> promise;
95+
OptionalByteRangeRequest byte_range;
96+
kvstore::ReadGenerationConditions generation_conditions;
97+
};
7198

7299
// Batch of read requests with an aggregate staleness bound, used by
73100
// `BatchReadEntry`.
74101
//
75102
// The aggregate staleness bound is set to the maximum staleness bound of all
76103
// individual requests.
104+
//
105+
// \tparam RequestType Must satisfy `IsByteRangeReadRequestLikeV`.
77106
template <typename RequestType>
78107
struct RequestBatch {
108+
static_assert(IsByteRangeReadRequestLikeV<RequestType>);
109+
79110
using Request = RequestType;
111+
80112
absl::Time staleness_bound = absl::InfinitePast();
81113
absl::InlinedVector<Request, 1> requests;
82114

@@ -123,10 +155,13 @@ template <typename DerivedDriver, typename RequestType,
123155
typename... BatchEntryKeyMember>
124156
class BatchReadEntry : public Batch::Impl::Entry {
125157
public:
158+
static_assert(IsByteRangeReadRequestLikeV<RequestType>);
159+
126160
using Driver = DerivedDriver;
127161
using BatchEntryKey =
128162
std::tuple<internal::IntrusivePtr<DerivedDriver>, BatchEntryKeyMember...>;
129163
using Request = RequestType;
164+
130165
using KeyParam =
131166
std::tuple<DerivedDriver*, KeyParamType<BatchEntryKeyMember>...>;
132167

@@ -204,7 +239,7 @@ class BatchReadEntry : public Batch::Impl::Entry {
204239
absl::Mutex mutex_;
205240

206241
void AddRequest(absl::Time staleness_bound, Request&& request) {
207-
absl::MutexLock lock(&mutex_);
242+
absl::MutexLock lock(mutex_);
208243
request_batch.AddRequest(staleness_bound, std::move(request));
209244
}
210245
};
@@ -217,10 +252,13 @@ void SetCommonResult(span<const Request> requests,
217252
Result<kvstore::ReadResult>&& result) {
218253
if (requests.empty()) return;
219254
for (size_t i = 1; i < requests.size(); ++i) {
220-
std::get<ByteRangeReadRequest>(requests[i]).promise.SetResult(result);
255+
if (requests[i].promise.result_needed()) {
256+
requests[i].promise.SetResult(result);
257+
}
258+
}
259+
if (requests[0].promise.result_needed()) {
260+
requests[0].promise.SetResult(std::move(result));
221261
}
222-
std::get<ByteRangeReadRequest>(requests[0])
223-
.promise.SetResult(std::move(result));
224262
}
225263
template <typename Requests>
226264
void SetCommonResult(const Requests& requests,
@@ -231,11 +269,10 @@ void SetCommonResult(const Requests& requests,
231269

232270
template <typename Request>
233271
void SortRequestsByStartByte(span<Request> requests) {
234-
std::sort(
235-
requests.begin(), requests.end(), [](const Request& a, const Request& b) {
236-
return std::get<ByteRangeReadRequest>(a).byte_range.inclusive_min <
237-
std::get<ByteRangeReadRequest>(b).byte_range.inclusive_min;
238-
});
272+
std::sort(requests.begin(), requests.end(),
273+
[](const Request& a, const Request& b) {
274+
return a.byte_range.inclusive_min < b.byte_range.inclusive_min;
275+
});
239276
}
240277

241278
// Resolves coalesced requests with the appropriate cord subranges.
@@ -244,19 +281,18 @@ void ResolveCoalescedRequests(ByteRange coalesced_byte_range,
244281
span<Request> coalesced_requests,
245282
kvstore::ReadResult&& read_result) {
246283
for (auto& request : coalesced_requests) {
247-
auto& byte_range_request = std::get<ByteRangeReadRequest>(request);
248284
kvstore::ReadResult sub_read_result;
249285
sub_read_result.stamp = read_result.stamp;
250286
sub_read_result.state = read_result.state;
251287
if (read_result.state == kvstore::ReadResult::kValue) {
252288
ABSL_DCHECK_EQ(coalesced_byte_range.size(), read_result.value.size());
253-
int64_t request_start = byte_range_request.byte_range.inclusive_min -
254-
coalesced_byte_range.inclusive_min;
255-
int64_t request_size = byte_range_request.byte_range.size();
289+
int64_t request_start =
290+
request.byte_range.inclusive_min - coalesced_byte_range.inclusive_min;
291+
int64_t request_size = request.byte_range.size();
256292
sub_read_result.value =
257293
read_result.value.Subcord(request_start, request_size);
258294
}
259-
byte_range_request.promise.SetResult(std::move(sub_read_result));
295+
request.promise.SetResult(std::move(sub_read_result));
260296
}
261297
}
262298

@@ -277,19 +313,17 @@ void ResolveCoalescedRequests(ByteRange coalesced_byte_range,
277313
template <typename Request, typename Predicate, typename Callback>
278314
void ForEachCoalescedRequest(span<Request> requests, Predicate predicate,
279315
Callback callback) {
316+
static_assert(IsByteRangeReadRequestLikeV<Request>);
317+
280318
SortRequestsByStartByte(requests);
281319

282320
size_t request_i = 0;
283321
while (request_i < requests.size()) {
284-
auto coalesced_byte_range =
285-
std::get<ByteRangeReadRequest>(requests[request_i])
286-
.byte_range.AsByteRange();
322+
auto coalesced_byte_range = requests[request_i].byte_range.AsByteRange();
287323
size_t end_request_i;
288324
for (end_request_i = request_i + 1; end_request_i < requests.size();
289325
++end_request_i) {
290-
auto next_byte_range =
291-
std::get<ByteRangeReadRequest>(requests[end_request_i])
292-
.byte_range.AsByteRange();
326+
auto next_byte_range = requests[end_request_i].byte_range.AsByteRange();
293327
if (next_byte_range.inclusive_min < coalesced_byte_range.exclusive_max ||
294328
predicate(coalesced_byte_range, next_byte_range.inclusive_min)) {
295329
coalesced_byte_range.exclusive_max = std::max(
@@ -314,12 +348,9 @@ void ForEachCoalescedRequest(span<Request> requests, Predicate predicate,
314348
template <typename Request>
315349
bool ValidateRequestGeneration(Request& request,
316350
const TimestampedStorageGeneration& stamp) {
317-
auto& byte_range_request = std::get<ByteRangeReadRequest>(request);
318-
if (!byte_range_request.promise.result_needed()) return false;
319-
if (!std::get<kvstore::ReadGenerationConditions>(request).Matches(
320-
stamp.generation)) {
321-
byte_range_request.promise.SetResult(
322-
kvstore::ReadResult::Unspecified(stamp));
351+
if (!request.promise.result_needed()) return false;
352+
if (!request.generation_conditions.Matches(stamp.generation)) {
353+
request.promise.SetResult(kvstore::ReadResult::Unspecified(stamp));
323354
return false;
324355
}
325356
return true;
@@ -334,14 +365,15 @@ bool ValidateRequestGeneration(Request& request,
334365
template <typename Request>
335366
bool ValidateRequestGenerationAndByteRange(
336367
Request& request, const TimestampedStorageGeneration& stamp, int64_t size) {
368+
static_assert(IsByteRangeReadRequestLikeV<Request>);
369+
static_assert(
370+
std::is_member_pointer<decltype(&Request::generation_conditions)>::value);
337371
if (!ValidateRequestGeneration(request, stamp)) {
338372
return false;
339373
}
340-
auto& byte_range_request = std::get<ByteRangeReadRequest>(request);
341374
TENSORSTORE_ASSIGN_OR_RETURN(
342-
byte_range_request.byte_range,
343-
byte_range_request.byte_range.Validate(size),
344-
(byte_range_request.promise.SetResult(std::move(_)), false));
375+
request.byte_range, request.byte_range.Validate(size),
376+
(request.promise.SetResult(std::move(_)), false));
345377
return true;
346378
}
347379

tensorstore/kvstore/file/file_key_value_store.cc

Lines changed: 8 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -364,7 +364,7 @@ Result<UniqueFileDescriptor> OpenValueFile(const std::string& path,
364364
class BatchReadTask;
365365
using BatchReadTaskBase = internal_kvstore_batch::BatchReadEntry<
366366
FileKeyValueStore,
367-
internal_kvstore_batch::ReadRequest<kvstore::ReadGenerationConditions>,
367+
/*ReadRequest=*/internal_kvstore_batch::ByteRangeGenerationReadRequest,
368368
// BatchEntryKey members:
369369
std::string /* file_path*/>;
370370

@@ -448,12 +448,9 @@ class BatchReadTask final
448448
}
449449

450450
if (requests.size() == 1) {
451-
auto& byte_range_request =
452-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(requests[0]);
453451
// Perform single read immediately.
454-
byte_range_request.promise.SetResult(
455-
DoByteRangeRead(byte_range_request.byte_range.AsByteRange()));
456-
452+
requests[0].promise.SetResult(
453+
DoByteRangeRead(requests[0].byte_range.AsByteRange()));
457454
return;
458455
}
459456

@@ -480,9 +477,7 @@ class BatchReadTask final
480477
int64_t inclusive_min = std::numeric_limits<int64_t>::max();
481478
int64_t total_size = 0;
482479
for (const auto& req : requests) {
483-
const auto byte_range =
484-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(req)
485-
.byte_range.AsByteRange();
480+
const auto byte_range = req.byte_range.AsByteRange();
486481
inclusive_min = std::min(inclusive_min, byte_range.inclusive_min);
487482
exclusive_max = std::max(exclusive_max, byte_range.exclusive_max);
488483
total_size += byte_range.size();
@@ -503,13 +498,11 @@ class BatchReadTask final
503498
} else if (mapped_result.ok()) {
504499
absl::Cord file_contents = std::move(mapped_result).value().as_cord();
505500
for (const auto& req : requests) {
506-
auto& byte_range_request =
507-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(req);
508-
ByteRange byte_range = byte_range_request.byte_range.AsByteRange();
501+
ByteRange byte_range = req.byte_range.AsByteRange();
509502
assert(byte_range.inclusive_min >= inclusive_min);
510503
absl::Cord subcord = file_contents.Subcord(
511504
byte_range.inclusive_min - inclusive_min, byte_range.size());
512-
byte_range_request.promise.SetResult(
505+
req.promise.SetResult(
513506
kvstore::ReadResult::Value(std::move(subcord), stamp_));
514507
}
515508
return true;
@@ -530,9 +523,7 @@ class BatchReadTask final
530523
// Determine if it's possible to read entire blocks.
531524
int64_t exclusive_max = 0;
532525
for (const auto& req : requests) {
533-
const auto byte_range =
534-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(req)
535-
.byte_range.AsByteRange();
526+
const auto byte_range = req.byte_range.AsByteRange();
536527
exclusive_max = std::max(exclusive_max, byte_range.exclusive_max);
537528
}
538529
exclusive_max = RoundUpTo(exclusive_max, block_alignment);
@@ -565,7 +556,7 @@ Future<ReadResult> FileKeyValueStore::Read(Key key, ReadOptions options) {
565556
auto [promise, future] = PromiseFuturePair<kvstore::ReadResult>::Make();
566557
BatchReadTask::MakeRequest<BatchReadTask>(
567558
*this, {std::move(key)}, options.batch, options.staleness_bound,
568-
BatchReadTask::Request{{std::move(promise), options.byte_range},
559+
BatchReadTask::Request{std::move(promise), options.byte_range,
569560
std::move(options.generation_conditions)});
570561
return std::move(future);
571562
}

tensorstore/kvstore/generic_coalescing_batch_util.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ namespace internal_kvstore_batch {
3636

3737
template <typename DerivedDriver>
3838
using GenericCoalescingBatchReadEntryBase =
39-
BatchReadEntry<DerivedDriver, ReadRequest<>,
39+
BatchReadEntry<DerivedDriver, /*ReadRequest=*/ByteRangeReadRequest,
4040
// BatchEntryKey members:
4141
kvstore::Key, kvstore::ReadGenerationConditions>;
4242

@@ -135,7 +135,7 @@ Future<kvstore::ReadResult> HandleBatchRequestByGenericByteRangeCoalescing(
135135
Entry::template MakeRequest<Entry>(
136136
driver, std::move(key), std::move(options.generation_conditions),
137137
options.batch, options.staleness_bound,
138-
typename Entry::Request{{std::move(promise), options.byte_range}});
138+
typename Entry::Request{std::move(promise), options.byte_range});
139139
return std::move(future);
140140
}
141141

tensorstore/kvstore/mock_kvstore.cc

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,9 @@ void MockKeyValueStore::BatchReadRequest::operator()(
8989
kvstore::ReadOptions options;
9090
options.staleness_bound = request_batch.staleness_bound;
9191
options.batch = batch;
92-
options.generation_conditions =
93-
std::get<kvstore::ReadGenerationConditions>(request);
94-
auto& byte_range_request =
95-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(request);
96-
options.byte_range = byte_range_request.byte_range;
97-
LinkResult(byte_range_request.promise,
98-
target->Read(key, std::move(options)));
92+
options.generation_conditions = request.generation_conditions;
93+
options.byte_range = request.byte_range;
94+
LinkResult(request.promise, target->Read(key, std::move(options)));
9995
}
10096
}
10197

@@ -104,7 +100,9 @@ Future<kvstore::ReadResult> MockKeyValueStore::Read(Key key,
104100
if (handle_batch_requests && options.batch) {
105101
class BatchEntry;
106102
using BatchEntryBase = internal_kvstore_batch::BatchReadEntry<
107-
MockKeyValueStore, BatchReadRequest::Request, kvstore::Key>;
103+
MockKeyValueStore, /*ReadRequest=*/BatchReadRequest::Request,
104+
// BatchEntryKey members:
105+
kvstore::Key>;
108106
class BatchEntry : public BatchEntryBase {
109107
public:
110108
using BatchEntryBase::BatchEntryBase;
@@ -120,13 +118,9 @@ Future<kvstore::ReadResult> MockKeyValueStore::Read(Key key,
120118
::nlohmann::json::array_t requests_log;
121119
for (const auto& request : request_batch.requests) {
122120
::nlohmann::json::object_t request_log;
123-
AddByteRangeToLogEntry(
124-
request_log,
125-
std::get<internal_kvstore_batch::ByteRangeReadRequest>(request)
126-
.byte_range);
127-
AddGenerationConditionsToLogEntry(
128-
request_log,
129-
std::get<kvstore::ReadGenerationConditions>(request));
121+
AddByteRangeToLogEntry(request_log, request.byte_range);
122+
AddGenerationConditionsToLogEntry(request_log,
123+
request.generation_conditions);
130124
requests_log.push_back(std::move(request_log));
131125
}
132126
log_entry.emplace("requests", std::move(requests_log));
@@ -147,7 +141,7 @@ Future<kvstore::ReadResult> MockKeyValueStore::Read(Key key,
147141
auto [promise, future] = PromiseFuturePair<kvstore::ReadResult>::Make();
148142
BatchEntry::MakeRequest<BatchEntry>(
149143
*this, std::move(key), options.batch, options.staleness_bound,
150-
BatchEntry::Request{{std::move(promise), options.byte_range},
144+
BatchEntry::Request{std::move(promise), options.byte_range,
151145
std::move(options.generation_conditions)});
152146
return std::move(future);
153147
}

tensorstore/kvstore/mock_kvstore.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
#include "tensorstore/kvstore/driver.h"
2727
#include "tensorstore/kvstore/generation.h"
2828
#include "tensorstore/kvstore/key_range.h"
29-
#include "tensorstore/kvstore/operations.h"
3029
#include "tensorstore/kvstore/spec.h"
3130
#include "tensorstore/kvstore/supported_features.h"
3231
#include "tensorstore/util/future.h"
@@ -58,9 +57,9 @@ class MockKeyValueStore : public kvstore::Driver {
5857
};
5958

6059
struct BatchReadRequest {
60+
using Request = internal_kvstore_batch::ByteRangeGenerationReadRequest;
61+
6162
Key key;
62-
using Request =
63-
internal_kvstore_batch::ReadRequest<kvstore::ReadGenerationConditions>;
6463
using RequestBatch = internal_kvstore_batch::RequestBatch<Request>;
6564
RequestBatch request_batch;
6665
void operator()(kvstore::DriverPtr target) const;

0 commit comments

Comments
 (0)