Skip to content

Commit 0b4d194

Browse files
amabluea-maurice
authored andcommitted
Hooked up the PruneForest so that cached values are appropriately pruned.
PiperOrigin-RevId: 285262340
1 parent ce8d881 commit 0b4d194

12 files changed

+280
-27
lines changed

database/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ set(desktop_SRCS
7272
src/desktop/connection/persistent_connection.cc
7373
src/desktop/connection/util_connection.cc
7474
src/desktop/connection/web_socket_client_impl.cc
75+
src/desktop/core/cache_policy.cc
7576
src/desktop/core/child_event_registration.cc
7677
src/desktop/core/compound_write.cc
7778
src/desktop/core/constants.cc
@@ -98,6 +99,7 @@ set(desktop_SRCS
9899
src/desktop/persistence/in_memory_persistence_storage_engine.cc
99100
src/desktop/persistence/noop_persistence_storage_engine.cc
100101
src/desktop/persistence/persistence_manager.cc
102+
src/desktop/persistence/prune_forest.cc
101103
src/desktop/push_child_name_generator.cc
102104
src/desktop/query_desktop.cc
103105
src/desktop/query_params_comparator.cc

database/src/desktop/core/repo.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -374,8 +374,7 @@ void Repo::UpdateChildren(const Path& path, const Variant& data,
374374

375375
// Start with our existing data and merge each child into it.
376376
Variant server_values = GenerateServerValues(server_time_offset_);
377-
CompoundWrite resolved =
378-
ResolveDeferredValueMerge(updates, server_values);
377+
CompoundWrite resolved = ResolveDeferredValueMerge(updates, server_values);
379378

380379
WriteId write_id = GetNextWriteId();
381380
std::vector<Event> events = server_sync_tree_->ApplyUserMerge(
@@ -425,14 +424,19 @@ void Repo::AckWriteAndRerunTransactions(WriteId write_id, const Path& path,
425424

426425
static UniquePtr<SyncTree> InitializeSyncTree(
427426
UniquePtr<ListenProvider> listen_provider, Logger* logger) {
427+
static const uint64_t kDefaultCacheSize = 10 * 1024 * 1024;
428+
428429
UniquePtr<WriteTree> pending_write_tree = MakeUnique<WriteTree>();
429430
UniquePtr<PersistenceStorageEngine> persistence_storage_engine =
430431
MakeUnique<InMemoryPersistenceStorageEngine>(logger);
431432
UniquePtr<TrackedQueryManager> tracked_query_manager =
432-
MakeUnique<TrackedQueryManager>(persistence_storage_engine.get());
433+
MakeUnique<TrackedQueryManager>(persistence_storage_engine.get(), logger);
434+
UniquePtr<CachePolicy> cache_policy =
435+
MakeUnique<LRUCachePolicy>(kDefaultCacheSize);
433436
UniquePtr<PersistenceManager> persistence_manager =
434437
MakeUnique<PersistenceManager>(std::move(persistence_storage_engine),
435-
std::move(tracked_query_manager));
438+
std::move(tracked_query_manager),
439+
std::move(cache_policy), logger);
436440
return MakeUnique<SyncTree>(std::move(pending_write_tree),
437441
std::move(persistence_manager),
438442
std::move(listen_provider));

database/src/desktop/core/tracked_query_manager.cc

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,14 @@
1616

1717
#include <algorithm>
1818
#include <cstdint>
19+
#include <ctime>
1920
#include <map>
2021

2122
#include "app/src/assert.h"
2223
#include "app/src/path.h"
2324
#include "database/src/common/query_spec.h"
2425
#include "database/src/desktop/persistence/persistence_storage_engine.h"
26+
#include "database/src/desktop/persistence/prune_forest.h"
2527
#include "database/src/desktop/util_desktop.h"
2628

2729
namespace firebase {
@@ -51,7 +53,7 @@ static bool IsQueryPrunablePredicate(const TrackedQuery& query) {
5153

5254
// Returns true if the given TrackedQuery is not prunable. A query is considered
5355
// prunable if it is not active.
54-
static bool IsQueryUnPrunablePredicate(const TrackedQuery& query) {
56+
static bool IsQueryUnprunablePredicate(const TrackedQuery& query) {
5557
return query.active;
5658
}
5759

@@ -69,10 +71,11 @@ static void AssertValidTrackedQuery(const QuerySpec& query_spec) {
6971
}
7072

7173
TrackedQueryManager::TrackedQueryManager(
72-
PersistenceStorageEngine* storage_engine)
74+
PersistenceStorageEngine* storage_engine, LoggerBase* logger)
7375
: storage_engine_(storage_engine),
7476
tracked_query_tree_(),
75-
next_query_id_(0) {
77+
next_query_id_(0),
78+
logger_(logger) {
7679
ResetPreviouslyActiveTrackedQueries();
7780

7881
// Populate our cache from the storage layer.
@@ -84,6 +87,8 @@ TrackedQueryManager::TrackedQueryManager(
8487
}
8588
}
8689

90+
TrackedQueryManager::~TrackedQueryManager() {}
91+
8792
const TrackedQuery* TrackedQueryManager::FindTrackedQuery(
8893
const QuerySpec& query_spec) const {
8994
QuerySpec normalized_spec = GetNormalizedQuery(query_spec);
@@ -114,7 +119,9 @@ void TrackedQueryManager::SetQueryActiveFlag(
114119
QuerySpec normalized_spec = GetNormalizedQuery(query_spec);
115120
const TrackedQuery* tracked_query = FindTrackedQuery(normalized_spec);
116121

117-
uint64_t last_use = 0;
122+
// TODO(amablue): Set up a more robust clock that won't get confused if, for
123+
// example, the system time changes while the app is running.
124+
uint64_t last_use = time(nullptr) * 1000;
118125
if (tracked_query != nullptr) {
119126
TrackedQuery updated_tracked_query = *tracked_query;
120127
updated_tracked_query.last_use = last_use;
@@ -173,6 +180,65 @@ bool TrackedQueryManager::IsQueryComplete(const QuerySpec& query_spec) {
173180
}
174181
}
175182

183+
static uint64_t CalculateCountToPrune(const CachePolicy& cache_policy,
184+
uint64_t prunable_count) {
185+
uint64_t count_to_keep = prunable_count;
186+
187+
// Prune by percentage.
188+
double percent_to_keep =
189+
1.0 - cache_policy.GetPercentOfQueriesToPruneAtOnce();
190+
count_to_keep = static_cast<uint64_t>(count_to_keep * percent_to_keep);
191+
192+
// Make sure we're not keeping more than the max.
193+
count_to_keep =
194+
std::min(count_to_keep, cache_policy.GetMaxNumberOfQueriesToKeep());
195+
196+
// Now we know how many to prune.
197+
return prunable_count - count_to_keep;
198+
}
199+
200+
PruneForest TrackedQueryManager::PruneOldQueries(
201+
const CachePolicy& cache_policy) {
202+
std::vector<TrackedQuery> prunable =
203+
GetQueriesMatching(IsQueryPrunablePredicate);
204+
uint64_t count_to_prune =
205+
CalculateCountToPrune(cache_policy, prunable.size());
206+
207+
logger_->LogDebug("Pruning old queries. Prunable: %i Count to prune: %i",
208+
static_cast<int>(prunable.size()),
209+
static_cast<int>(count_to_prune));
210+
211+
std::sort(prunable.begin(), prunable.end(),
212+
[](const TrackedQuery& q1, const TrackedQuery& q2) {
213+
return q1.last_use < q2.last_use;
214+
});
215+
216+
// Prune the queries that are no longer needed.
217+
PruneForest forest;
218+
PruneForestRef forest_ref(&forest);
219+
for (uint64_t i = 0; i < count_to_prune; i++) {
220+
const TrackedQuery& to_prune = prunable[i];
221+
forest_ref.Prune(to_prune.query_spec.path);
222+
RemoveTrackedQuery(to_prune.query_spec);
223+
}
224+
// Keep the rest of the prunable queries.
225+
for (uint64_t i = count_to_prune; i < prunable.size(); i++) {
226+
const TrackedQuery& to_keep = prunable[i];
227+
forest_ref.Keep(to_keep.query_spec.path);
228+
}
229+
// Also keep the unprunable queries.
230+
std::vector<TrackedQuery> unprunable =
231+
GetQueriesMatching(IsQueryUnprunablePredicate);
232+
233+
logger_->LogDebug("Unprunable queries: %i",
234+
static_cast<int>(unprunable.size()));
235+
for (const TrackedQuery& to_keep : unprunable) {
236+
forest_ref.Keep(to_keep.query_spec.path);
237+
}
238+
239+
return forest;
240+
}
241+
176242
std::set<std::string> TrackedQueryManager::GetKnownCompleteChildren(
177243
const Path& path) {
178244
QuerySpec default_at_path(path);

database/src/desktop/core/tracked_query_manager.h

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,14 @@
1818
#include <cstdint>
1919
#include <map>
2020
#include <set>
21+
22+
#include "app/src/logger.h"
2123
#include "app/src/optional.h"
2224
#include "app/src/path.h"
2325
#include "database/src/common/query_spec.h"
26+
#include "database/src/desktop/core/cache_policy.h"
2427
#include "database/src/desktop/core/tree.h"
28+
#include "database/src/desktop/persistence/prune_forest.h"
2529

2630
namespace firebase {
2731
namespace database {
@@ -35,7 +39,12 @@ struct TrackedQuery {
3539
enum CompletionStatus { kIncomplete, kComplete };
3640
enum ActivityStatus { kInactive, kActive };
3741

38-
TrackedQuery() : query_id(), query_spec(), last_use(), complete(), active() {}
42+
TrackedQuery()
43+
: query_id(0),
44+
query_spec(),
45+
last_use(0),
46+
complete(false),
47+
active(false) {}
3948

4049
TrackedQuery(QueryId _query_id, const QuerySpec& _query_spec,
4150
uint64_t _last_use, CompletionStatus _complete,
@@ -91,6 +100,10 @@ class TrackedQueryManagerInterface {
91100
// complete.
92101
virtual bool IsQueryComplete(const QuerySpec& query) = 0;
93102

103+
// Remove queries that no longer need to be cached based on the given cache
104+
// policy.
105+
virtual PruneForest PruneOldQueries(const CachePolicy& cache_policy) = 0;
106+
94107
// Return the keys of the completed TrackedQueries at the given location.
95108
virtual std::set<std::string> GetKnownCompleteChildren(const Path& path) = 0;
96109

@@ -108,7 +121,10 @@ class TrackedQueryManagerInterface {
108121

109122
class TrackedQueryManager : public TrackedQueryManagerInterface {
110123
public:
111-
explicit TrackedQueryManager(PersistenceStorageEngine* storage_engine);
124+
TrackedQueryManager(PersistenceStorageEngine* storage_engine,
125+
LoggerBase* logger);
126+
127+
~TrackedQueryManager() override;
112128

113129
// Find and return the TrackedQuery associated with the given QuerySpec, or
114130
// nullptr if there is no associated TrackedQuery.
@@ -136,6 +152,10 @@ class TrackedQueryManager : public TrackedQueryManagerInterface {
136152
// complete.
137153
bool IsQueryComplete(const QuerySpec& query) override;
138154

155+
// Remove queries that no longer need to be cached based on the given cache
156+
// policy.
157+
PruneForest PruneOldQueries(const CachePolicy& cache_policy) override;
158+
139159
// Return the keys of the completed TrackedQueries at the given location.
140160
std::set<std::string> GetKnownCompleteChildren(const Path& path) override;
141161

@@ -183,6 +203,8 @@ class TrackedQueryManager : public TrackedQueryManagerInterface {
183203

184204
// ID we'll assign to the next tracked query.
185205
QueryId next_query_id_;
206+
207+
LoggerBase* logger_;
186208
};
187209

188210
} // namespace internal

database/src/desktop/persistence/in_memory_persistence_storage_engine.cc

Lines changed: 80 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
#include "app/src/assert.h"
2323
#include "app/src/include/firebase/variant.h"
2424
#include "app/src/log.h"
25+
#include "app/src/logger.h"
2526
#include "app/src/path.h"
2627
#include "app/src/variant_util.h"
2728
#include "database/src/common/query_spec.h"
@@ -36,7 +37,7 @@ namespace database {
3637
namespace internal {
3738

3839
InMemoryPersistenceStorageEngine::InMemoryPersistenceStorageEngine(
39-
Logger* logger)
40+
LoggerBase* logger)
4041
: server_cache_(), inside_transaction_(false), logger_(logger) {}
4142

4243
InMemoryPersistenceStorageEngine::~InMemoryPersistenceStorageEngine() {}
@@ -91,16 +92,68 @@ void InMemoryPersistenceStorageEngine::MergeIntoServerCache(
9192
const Path& path, const Variant& data) {
9293
VerifyInTransaction();
9394
Variant* target = MakeVariantAtPath(&server_cache_, path);
94-
if (!target->is_map()) *target = Variant::EmptyMap();
95-
PatchVariant(data, target);
95+
if (data.is_map()) {
96+
if (!target->is_map()) *target = Variant::EmptyMap();
97+
PatchVariant(data, target);
98+
} else {
99+
*target = data;
100+
}
96101
// Clean up in case anything was removed.
97102
PruneNulls(target);
98103
}
99104

100105
void InMemoryPersistenceStorageEngine::MergeIntoServerCache(
101106
const Path& path, const CompoundWrite& children) {
102-
// TODO(amablue)
103107
VerifyInTransaction();
108+
children.write_tree().CallOnEach(
109+
Path(), [this, &path](const Path& child_path, const Variant& value) {
110+
this->MergeIntoServerCache(path.GetChild(child_path), value);
111+
});
112+
}
113+
114+
static uint64_t EstimateVariantMemoryUsage(const Variant& variant) {
115+
switch (variant.type()) {
116+
case Variant::kTypeNull:
117+
case Variant::kTypeInt64:
118+
case Variant::kTypeDouble:
119+
case Variant::kTypeBool: {
120+
return sizeof(Variant);
121+
}
122+
case Variant::kTypeStaticString: {
123+
return sizeof(Variant) + strlen(variant.string_value());
124+
}
125+
case Variant::kTypeMutableString: {
126+
return sizeof(Variant) + variant.mutable_string().size();
127+
}
128+
case Variant::kTypeVector: {
129+
uint64_t sum_total = 0;
130+
for (const auto& item : variant.vector()) {
131+
sum_total += EstimateVariantMemoryUsage(item);
132+
}
133+
return sizeof(Variant) + sum_total;
134+
}
135+
case Variant::kTypeMap: {
136+
uint64_t sum_total = 0;
137+
for (const auto& key_value_pair : variant.map()) {
138+
sum_total += EstimateVariantMemoryUsage(key_value_pair.first);
139+
sum_total += EstimateVariantMemoryUsage(key_value_pair.second);
140+
}
141+
return sizeof(Variant) + sum_total;
142+
}
143+
case Variant::kTypeStaticBlob:
144+
case Variant::kTypeMutableBlob: {
145+
return sizeof(Variant) + variant.blob_size();
146+
}
147+
default: {
148+
FIREBASE_DEV_ASSERT_MESSAGE(false, "Unhandled variant type.");
149+
return 0;
150+
}
151+
}
152+
}
153+
154+
uint64_t InMemoryPersistenceStorageEngine::ServerCacheEstimatedSizeInBytes()
155+
const {
156+
return EstimateVariantMemoryUsage(server_cache_);
104157
}
105158

106159
void InMemoryPersistenceStorageEngine::SaveTrackedQuery(
@@ -139,7 +192,9 @@ void InMemoryPersistenceStorageEngine::UpdateTrackedQueryKeys(
139192

140193
std::set<std::string>& tracked_keys = tracked_query_keys_[query_id];
141194
tracked_keys.insert(added.begin(), added.end());
142-
tracked_keys.insert(removed.begin(), removed.end());
195+
for (const std::string& to_remove : removed) {
196+
tracked_keys.erase(to_remove);
197+
}
143198
}
144199

145200
std::set<std::string> InMemoryPersistenceStorageEngine::LoadTrackedQueryKeys(
@@ -157,16 +212,35 @@ std::set<std::string> InMemoryPersistenceStorageEngine::LoadTrackedQueryKeys(
157212
return result;
158213
}
159214

215+
void PruneVariant(const Path& root, const PruneForestRef prune_forest,
216+
Variant* variant) {
217+
Variant result = prune_forest.FoldKeptNodes(
218+
Variant::Null(),
219+
[&root, &variant](const Path& relative_path, Variant accum) {
220+
Variant child = VariantGetChild(variant, root.GetChild(relative_path));
221+
VariantUpdateChild(&accum, relative_path, child);
222+
return accum;
223+
});
224+
VariantUpdateChild(variant, root, result);
225+
}
226+
227+
void InMemoryPersistenceStorageEngine::PruneCache(
228+
const Path& root, const PruneForestRef& prune_forest) {
229+
PruneVariant(root, prune_forest, &server_cache_);
230+
}
231+
160232
bool InMemoryPersistenceStorageEngine::BeginTransaction() {
161233
FIREBASE_DEV_ASSERT_MESSAGE(!inside_transaction_,
162-
"runInTransaction called when an existing "
234+
"RunInTransaction called when an existing "
163235
"transaction is already in progress.");
164236
logger_->LogDebug("Starting transaction.");
165237
inside_transaction_ = true;
166238
return true;
167239
}
168240

169241
void InMemoryPersistenceStorageEngine::EndTransaction() {
242+
FIREBASE_DEV_ASSERT_MESSAGE(
243+
inside_transaction_, "EndTransaction called when not in a transaction");
170244
inside_transaction_ = false;
171245
logger_->LogDebug("Transaction completed.");
172246
}

0 commit comments

Comments
 (0)