Skip to content

Commit d925a22

Browse files
committed
thread_safe_memoizer.h reword
1 parent f5ae906 commit d925a22

File tree

9 files changed

+111
-108
lines changed

9 files changed

+111
-108
lines changed

Firestore/core/src/core/composite_filter.cc

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,13 @@ const FieldFilter* CompositeFilter::Rep::FindFirstMatchingFilter(
142142
return nullptr;
143143
}
144144

145-
const std::vector<FieldFilter>& CompositeFilter::Rep::GetFlattenedFilters()
146-
const {
147-
return memoized_flattened_filters_.memoize([&]() {
148-
auto flattened_filters = absl::make_unique<std::vector<FieldFilter>>();
149-
for (const auto& filter : filters())
150-
std::copy(filter.GetFlattenedFilters().begin(),
151-
filter.GetFlattenedFilters().end(),
152-
std::back_inserter(*flattened_filters));
153-
return flattened_filters;
154-
});
145+
std::shared_ptr<std::vector<FieldFilter>> CompositeFilter::Rep::CalculateFlattenedFilters() const {
146+
auto flattened_filters = absl::make_unique<std::vector<FieldFilter>>();
147+
for (const auto& filter : filters())
148+
std::copy(filter.GetFlattenedFilters().begin(),
149+
filter.GetFlattenedFilters().end(),
150+
std::back_inserter(*flattened_filters));
151+
return flattened_filters;
155152
}
156153

157154
} // namespace core

Firestore/core/src/core/composite_filter.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ class CompositeFilter : public Filter {
138138
return filters_.empty();
139139
}
140140

141-
const std::vector<FieldFilter>& GetFlattenedFilters() const override;
141+
std::shared_ptr<std::vector<FieldFilter>> CalculateFlattenedFilters() const override;
142142

143143
std::vector<Filter> GetFilters() const override {
144144
return filters();

Firestore/core/src/core/field_filter.cc

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,13 +123,11 @@ FieldFilter::FieldFilter(std::shared_ptr<const Filter::Rep> rep)
123123
: Filter(std::move(rep)) {
124124
}
125125

126-
const std::vector<FieldFilter>& FieldFilter::Rep::GetFlattenedFilters() const {
126+
std::shared_ptr<std::vector<FieldFilter>> FieldFilter::Rep::CalculateFlattenedFilters() const {
127127
// This is already a field filter, so we return a vector of size one.
128-
return memoized_flattened_filters_.memoize([&]() {
129-
auto filters = absl::make_unique<std::vector<FieldFilter>>();
130-
filters->push_back(FieldFilter(std::make_shared<const Rep>(*this)));
131-
return filters;
132-
});
128+
auto filters = std::make_shared<std::vector<FieldFilter>>();
129+
filters->push_back(FieldFilter(std::make_shared<const Rep>(*this)));
130+
return filters;
133131
}
134132

135133
std::vector<Filter> FieldFilter::Rep::GetFilters() const {

Firestore/core/src/core/field_filter.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,6 @@ class FieldFilter : public Filter {
117117
return false;
118118
}
119119

120-
const std::vector<FieldFilter>& GetFlattenedFilters() const override;
121-
122120
std::vector<Filter> GetFilters() const override;
123121

124122
protected:
@@ -140,6 +138,8 @@ class FieldFilter : public Filter {
140138

141139
bool MatchesComparison(util::ComparisonResult comparison) const;
142140

141+
std::shared_ptr<std::vector<FieldFilter>> CalculateFlattenedFilters() const override;
142+
143143
private:
144144
friend class FieldFilter;
145145

Firestore/core/src/core/filter.h

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -145,21 +145,21 @@ class Filter {
145145

146146
virtual bool IsEmpty() const = 0;
147147

148-
virtual const std::vector<FieldFilter>& GetFlattenedFilters() const = 0;
148+
virtual const std::vector<FieldFilter>& GetFlattenedFilters() const {
149+
return flattened_filters_.value();
150+
}
149151

150152
virtual std::vector<Filter> GetFilters() const = 0;
151153

154+
protected:
155+
virtual std::shared_ptr<std::vector<FieldFilter>> CalculateFlattenedFilters() const = 0;
156+
157+
private:
152158
/**
153159
* Memoized list of all field filters that can be found by
154160
* traversing the tree of filters contained in this composite filter.
155-
*
156-
* Use a `std::shared_ptr<ThreadSafeMemoizer>` rather than using
157-
* `ThreadSafeMemoizer` directly so that this class is copyable
158-
* (`ThreadSafeMemoizer` is not copyable because of its `std::once_flag`
159-
* member variable, which is not copyable).
160161
*/
161-
mutable util::ThreadSafeMemoizer<std::vector<FieldFilter>>
162-
memoized_flattened_filters_;
162+
mutable util::ThreadSafeMemoizer<std::vector<FieldFilter>> flattened_filters_{[&] { return CalculateFlattenedFilters(); }};
163163
};
164164

165165
explicit Filter(std::shared_ptr<const Rep>&& rep) : rep_(rep) {

Firestore/core/src/core/query.cc

Lines changed: 37 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -92,44 +92,42 @@ absl::optional<Operator> Query::FindOpInsideFilters(
9292
return absl::nullopt;
9393
}
9494

95-
const std::vector<OrderBy>& Query::normalized_order_bys() const {
96-
return memoized_normalized_order_bys_.memoize([&]() {
97-
// Any explicit order by fields should be added as is.
98-
auto result = absl::make_unique<std::vector<OrderBy>>(explicit_order_bys_);
99-
std::set<FieldPath> fieldsNormalized;
100-
for (const OrderBy& order_by : explicit_order_bys_) {
101-
fieldsNormalized.insert(order_by.field());
102-
}
95+
std::shared_ptr<std::vector<OrderBy>> Query::calculate_normalized_order_bys() const {
96+
auto result = std::make_shared<std::vector<OrderBy>>(explicit_order_bys_);
10397

104-
// The order of the implicit ordering always matches the last explicit order
105-
// by.
106-
Direction last_direction = explicit_order_bys_.empty()
107-
? Direction::Ascending
108-
: explicit_order_bys_.back().direction();
109-
110-
// Any inequality fields not explicitly ordered should be implicitly ordered
111-
// in a lexicographical order. When there are multiple inequality filters on
112-
// the same field, the field should be added only once. Note:
113-
// `std::set<model::FieldPath>` sorts the key field before other fields.
114-
// However, we want the key field to be sorted last.
115-
const std::set<model::FieldPath> inequality_fields =
116-
InequalityFilterFields();
117-
118-
for (const model::FieldPath& field : inequality_fields) {
119-
if (fieldsNormalized.find(field) == fieldsNormalized.end() &&
120-
!field.IsKeyFieldPath()) {
121-
result->push_back(OrderBy(field, last_direction));
122-
}
123-
}
98+
std::set<FieldPath> fieldsNormalized;
99+
for (const OrderBy& order_by : explicit_order_bys_) {
100+
fieldsNormalized.insert(order_by.field());
101+
}
124102

125-
// Add the document key field to the last if it is not explicitly ordered.
126-
if (fieldsNormalized.find(FieldPath::KeyFieldPath()) ==
127-
fieldsNormalized.end()) {
128-
result->push_back(OrderBy(FieldPath::KeyFieldPath(), last_direction));
129-
}
103+
// The order of the implicit ordering always matches the last explicit order
104+
// by.
105+
Direction last_direction = explicit_order_bys_.empty()
106+
? Direction::Ascending
107+
: explicit_order_bys_.back().direction();
108+
109+
// Any inequality fields not explicitly ordered should be implicitly ordered
110+
// in a lexicographical order. When there are multiple inequality filters on
111+
// the same field, the field should be added only once. Note:
112+
// `std::set<model::FieldPath>` sorts the key field before other fields.
113+
// However, we want the key field to be sorted last.
114+
const std::set<model::FieldPath> inequality_fields =
115+
InequalityFilterFields();
116+
117+
for (const model::FieldPath& field : inequality_fields) {
118+
if (fieldsNormalized.find(field) == fieldsNormalized.end() &&
119+
!field.IsKeyFieldPath()) {
120+
result->push_back(OrderBy(field, last_direction));
121+
}
122+
}
130123

131-
return result;
132-
});
124+
// Add the document key field to the last if it is not explicitly ordered.
125+
if (fieldsNormalized.find(FieldPath::KeyFieldPath()) ==
126+
fieldsNormalized.end()) {
127+
result->push_back(OrderBy(FieldPath::KeyFieldPath(), last_direction));
128+
}
129+
130+
return result;
133131
}
134132

135133
LimitType Query::limit_type() const {
@@ -297,16 +295,12 @@ std::string Query::ToString() const {
297295
return absl::StrCat("Query(canonical_id=", CanonicalId(), ")");
298296
}
299297

300-
const Target& Query::ToTarget() const& {
301-
return memoized_target_.memoize([&]() {
302-
return absl::make_unique<Target>(ToTarget(normalized_order_bys()));
303-
});
298+
std::shared_ptr<Target> Query::calculate_target() const {
299+
return std::make_shared<Target>(ToTarget(normalized_order_bys()));
304300
}
305301

306-
const Target& Query::ToAggregateTarget() const& {
307-
return memoized_aggregate_target_.memoize([&]() {
308-
return absl::make_unique<Target>(ToTarget(explicit_order_bys_));
309-
});
302+
std::shared_ptr<Target> Query::calculate_aggregate_target() const {
303+
return std::make_shared<Target>(ToTarget(explicit_order_bys_));
310304
}
311305

312306
Target Query::ToTarget(const std::vector<OrderBy>& order_bys) const {

Firestore/core/src/core/query.h

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,9 @@ class Query {
148148
* This might include additional sort orders added implicitly to match the
149149
* backend behavior.
150150
*/
151-
const std::vector<OrderBy>& normalized_order_bys() const;
151+
const std::vector<OrderBy>& normalized_order_bys() const {
152+
return normalized_order_bys_.value();
153+
}
152154

153155
bool has_limit() const {
154156
return limit_ != Target::kNoLimit;
@@ -246,15 +248,19 @@ class Query {
246248
* Returns a `Target` instance this query will be mapped to in backend
247249
* and local store.
248250
*/
249-
const Target& ToTarget() const&;
251+
const Target& ToTarget() const& {
252+
return target_.value();
253+
}
250254

251255
/**
252256
* Returns a `Target` instance this query will be mapped to in backend
253257
* and local store, for use within an aggregate query. Unlike targets
254258
* for non-aggregate queries, aggregate query targets do not contain
255259
* normalized order-bys, they only contain explicit order-bys.
256260
*/
257-
const Target& ToAggregateTarget() const&;
261+
const Target& ToAggregateTarget() const& {
262+
return aggregate_target_.value();
263+
}
258264

259265
friend std::ostream& operator<<(std::ostream& os, const Query& query);
260266

@@ -295,16 +301,19 @@ class Query {
295301
// member variable, which is not copyable).
296302

297303
// The memoized list of sort orders.
304+
std::shared_ptr<std::vector<OrderBy>> calculate_normalized_order_bys() const;
298305
mutable util::ThreadSafeMemoizer<std::vector<OrderBy>>
299-
memoized_normalized_order_bys_;
306+
normalized_order_bys_{[&] {return calculate_normalized_order_bys(); }};
300307

301308
// The corresponding Target of this Query instance.
302-
mutable util::ThreadSafeMemoizer<Target> memoized_target_;
309+
std::shared_ptr<Target> calculate_target() const;
310+
mutable util::ThreadSafeMemoizer<Target> target_{[&] {return calculate_target(); }};
303311

304312
// The corresponding aggregate Target of this Query instance. Unlike targets
305313
// for non-aggregate queries, aggregate query targets do not contain
306314
// normalized order-bys, they only contain explicit order-bys.
307-
mutable util::ThreadSafeMemoizer<Target> memoized_aggregate_target_;
315+
std::shared_ptr<Target> calculate_aggregate_target() const;
316+
mutable util::ThreadSafeMemoizer<Target> aggregate_target_{[&] {return calculate_aggregate_target(); }};;
308317
};
309318

310319
bool operator==(const Query& lhs, const Query& rhs);

Firestore/core/src/util/thread_safe_memoizer.h

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <atomic>
2121
#include <functional>
2222
#include <memory>
23+
#include <utility>
2324

2425
namespace firebase {
2526
namespace firestore {
@@ -32,14 +33,25 @@ namespace util {
3233
template <typename T>
3334
class ThreadSafeMemoizer {
3435
public:
35-
ThreadSafeMemoizer()
36-
: memoized_(new std::atomic<T*>(nullptr), MemoizedValueDeleter) {
36+
explicit ThreadSafeMemoizer(std::function<std::shared_ptr<T>()> func) : func_(std::move(func)) {
3737
}
3838

39-
ThreadSafeMemoizer(const ThreadSafeMemoizer& other) = default;
40-
ThreadSafeMemoizer& operator=(const ThreadSafeMemoizer& other) = default;
41-
ThreadSafeMemoizer(ThreadSafeMemoizer&& other) = default;
42-
ThreadSafeMemoizer& operator=(ThreadSafeMemoizer&& other) = default;
39+
ThreadSafeMemoizer(const ThreadSafeMemoizer& other) : func_(other.func_), memoized_(std::atomic_load(&other.memoized_)) {}
40+
41+
ThreadSafeMemoizer& operator=(const ThreadSafeMemoizer& other) {
42+
func_ = other.func_;
43+
std::atomic_store(&memoized_, std::atomic_load(&other.memoized_));
44+
return *this;
45+
}
46+
47+
ThreadSafeMemoizer(ThreadSafeMemoizer&& other) noexcept : func_(std::move(other.func_)), memoized_(std::atomic_load(&other.memoized_)) {
48+
}
49+
50+
ThreadSafeMemoizer& operator=(ThreadSafeMemoizer&& other) noexcept {
51+
func_ = std::move(other.func_);
52+
std::atomic_store(&memoized_, std::atomic_load(&other.memoized_));
53+
return *this;
54+
}
4355

4456
/**
4557
* Memoize a value.
@@ -58,28 +70,24 @@ class ThreadSafeMemoizer {
5870
* No reference to the given function is retained by this object, and the
5971
* function be called synchronously, if it is called at all.
6072
*/
61-
const T& memoize(std::function<std::unique_ptr<T>()> func) {
73+
const T& value() {
74+
std::shared_ptr<T> old_memoized = std::atomic_load(&memoized_);
6275
while (true) {
63-
T* old_memoized = memoized_->load();
6476
if (old_memoized) {
6577
return *old_memoized;
6678
}
6779

68-
std::unique_ptr<T> new_memoized = func();
80+
std::shared_ptr<T> new_memoized = func_();
6981

70-
if (memoized_->compare_exchange_weak(old_memoized, new_memoized.get())) {
71-
return *new_memoized.release();
82+
if (std::atomic_compare_exchange_weak(&memoized_, &old_memoized, new_memoized)) {
83+
return *new_memoized;
7284
}
7385
}
7486
}
7587

7688
private:
77-
std::shared_ptr<std::atomic<T*>> memoized_;
78-
79-
static void MemoizedValueDeleter(std::atomic<T*>* value) {
80-
delete value->load();
81-
delete value;
82-
}
89+
std::function<std::shared_ptr<T>()> func_;
90+
std::shared_ptr<T> memoized_;
8391
};
8492

8593
} // namespace util

Firestore/core/test/unit/util/thread_safe_memoizer_test.cc

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -20,34 +20,33 @@
2020
#include "absl/memory/memory.h"
2121
#include "gtest/gtest.h"
2222

23-
namespace firebase {
24-
namespace firestore {
25-
namespace util {
23+
namespace {
24+
25+
using firebase::firestore::util::ThreadSafeMemoizer;
2626

2727
TEST(ThreadSafeMemoizerTest, MultiThreadedMemoization) {
2828
std::atomic<int> global_int{77};
2929

30-
auto expensive_lambda = [&]() {
30+
auto expensive_lambda = [&] {
3131
// Simulate an expensive operation
3232
std::this_thread::sleep_for(std::chrono::milliseconds(100));
3333
// If the lambda gets executed multiple times, threads will see incremented
3434
// `global_int`.
35-
global_int++;
36-
return absl::make_unique<int>(global_int.load());
35+
++global_int;
36+
return std::make_shared<int>(global_int.load());
3737
};
3838

39-
const int num_threads = 5;
40-
const int expected_result = 78;
39+
constexpr int num_threads = 5;
40+
constexpr int expected_result = 78;
4141

4242
// Create a thread safe memoizer and multiple threads.
43-
util::ThreadSafeMemoizer<int> memoized_result;
43+
ThreadSafeMemoizer<int> memoized_result(expensive_lambda);
4444
std::vector<std::thread> threads;
4545

4646
for (int i = 0; i < num_threads; ++i) {
4747
threads.emplace_back(
48-
[&memoized_result, expected_result, &expensive_lambda]() {
49-
const int& actual_result = memoized_result.memoize(expensive_lambda);
50-
48+
[&memoized_result, expected_result] {
49+
const int& actual_result = memoized_result.value();
5150
// Verify that all threads get the same memoized result.
5251
EXPECT_EQ(actual_result, expected_result);
5352
});
@@ -58,6 +57,4 @@ TEST(ThreadSafeMemoizerTest, MultiThreadedMemoization) {
5857
}
5958
}
6059

61-
} // namespace util
62-
} // namespace firestore
63-
} // namespace firebase
60+
} // namespace

0 commit comments

Comments
 (0)