Skip to content

Commit 68a2ce1

Browse files
committed
refactor(hermes): state->cache downcasting
1 parent 62d189e commit 68a2ce1

File tree

4 files changed

+102
-103
lines changed

4 files changed

+102
-103
lines changed

hermes/src/aggregate.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ use {
2323
state::{
2424
benchmarks::Benchmarks,
2525
cache::{
26-
AggregateCache,
26+
Cache,
2727
MessageState,
2828
MessageStateFilter,
2929
},
@@ -336,7 +336,7 @@ async fn get_verified_price_feeds<S>(
336336
request_time: RequestTime,
337337
) -> Result<PriceFeedsWithUpdateData>
338338
where
339-
S: AggregateCache,
339+
S: Cache,
340340
{
341341
let messages = state
342342
.fetch_message_states(
@@ -396,7 +396,7 @@ pub async fn get_price_feeds_with_update_data<S>(
396396
request_time: RequestTime,
397397
) -> Result<PriceFeedsWithUpdateData>
398398
where
399-
S: AggregateCache,
399+
S: Cache,
400400
S: Benchmarks,
401401
{
402402
match get_verified_price_feeds(state, price_ids, request_time.clone()).await {
@@ -412,7 +412,7 @@ where
412412

413413
pub async fn get_price_feed_ids<S>(state: &S) -> HashSet<PriceIdentifier>
414414
where
415-
S: AggregateCache,
415+
S: Cache,
416416
{
417417
state
418418
.message_state_keys()
@@ -468,10 +468,7 @@ mod test {
468468
Accumulator,
469469
},
470470
hashers::keccak256_160::Keccak160,
471-
messages::{
472-
Message,
473-
PriceFeedMessage,
474-
},
471+
messages::PriceFeedMessage,
475472
wire::v1::{
476473
AccumulatorUpdateData,
477474
Proof,

hermes/src/aggregate/wormhole_merkle.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use {
77
crate::{
88
network::wormhole::VaaBytes,
99
state::cache::{
10-
AggregateCache,
10+
Cache,
1111
MessageState,
1212
},
1313
},
@@ -70,14 +70,14 @@ impl From<MessageState> for RawMessageWithMerkleProof {
7070
}
7171

7272
pub async fn store_wormhole_merkle_verified_message<S>(
73-
store: &S,
73+
state: &S,
7474
root: WormholeMerkleRoot,
7575
vaa: VaaBytes,
7676
) -> Result<()>
7777
where
78-
S: AggregateCache,
78+
S: Cache,
7979
{
80-
store
80+
state
8181
.store_wormhole_merkle_state(WormholeMerkleState { root, vaa })
8282
.await?;
8383
Ok(())

hermes/src/state.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! This module contains the global state of the application.
22
33
use {
4-
self::cache::Cache,
4+
self::cache::CacheState,
55
crate::{
66
aggregate::{
77
AggregateState,
@@ -31,7 +31,7 @@ pub mod cache;
3131
pub struct State {
3232
/// Storage is a short-lived cache of the state of all the updates that have been passed to the
3333
/// store.
34-
pub cache: Cache,
34+
pub cache: CacheState,
3535

3636
/// Sequence numbers of lately observed Vaas. Store uses this set
3737
/// to ignore the previously observed Vaas as a performance boost.
@@ -64,7 +64,7 @@ impl State {
6464
) -> Arc<Self> {
6565
let mut metrics_registry = Registry::default();
6666
Arc::new(Self {
67-
cache: Cache::new(cache_size),
67+
cache: CacheState::new(cache_size),
6868
observed_vaa_seqs: RwLock::new(Default::default()),
6969
guardian_set: RwLock::new(Default::default()),
7070
api_update_tx: update_tx,

hermes/src/state/cache.rs

Lines changed: 90 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use {
2+
super::State,
23
crate::aggregate::{
34
wormhole_merkle::WormholeMerkleState,
45
AccumulatorMessages,
@@ -96,79 +97,42 @@ pub enum MessageStateFilter {
9697
Only(MessageType),
9798
}
9899

99-
pub struct Cache {
100-
/// Accumulator messages cache
101-
///
102-
/// We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
103-
accumulator_messages_cache: Arc<RwLock<BTreeMap<Slot, AccumulatorMessages>>>,
104-
105-
/// Wormhole merkle state cache
106-
///
107-
/// We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
108-
wormhole_merkle_state_cache: Arc<RwLock<BTreeMap<Slot, WormholeMerkleState>>>,
100+
/// A Cache of AccumulatorMessage by slot. We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
101+
type AccumulatorMessagesCache = Arc<RwLock<BTreeMap<Slot, AccumulatorMessages>>>;
109102

110-
message_cache: Arc<RwLock<HashMap<MessageStateKey, BTreeMap<MessageStateTime, MessageState>>>>,
111-
cache_size: u64,
112-
}
113-
114-
async fn retrieve_message_state(
115-
cache: &Cache,
116-
key: MessageStateKey,
117-
request_time: RequestTime,
118-
) -> Option<MessageState> {
119-
match cache.message_cache.read().await.get(&key) {
120-
Some(key_cache) => {
121-
match request_time {
122-
RequestTime::Latest => key_cache.last_key_value().map(|(_, v)| v).cloned(),
123-
RequestTime::FirstAfter(time) => {
124-
// If the requested time is before the first element in the vector, we are
125-
// not sure that the first element is the closest one.
126-
if let Some((_, oldest_record_value)) = key_cache.first_key_value() {
127-
if time < oldest_record_value.time().publish_time {
128-
return None;
129-
}
130-
}
103+
/// A Cache of WormholeMerkleState by slot. We do not write to this cache much, so we can use a simple RwLock instead of a DashMap.
104+
type WormholeMerkleStateCache = Arc<RwLock<BTreeMap<Slot, WormholeMerkleState>>>;
131105

132-
let lookup_time = MessageStateTime {
133-
publish_time: time,
134-
slot: 0,
135-
};
106+
/// A Cache of `Time<->MessageState` by feed id.
107+
type MessageCache = Arc<RwLock<HashMap<MessageStateKey, BTreeMap<MessageStateTime, MessageState>>>>;
136108

137-
// Get the first element that is greater than or equal to the lookup time.
138-
key_cache
139-
.lower_bound(Bound::Included(&lookup_time))
140-
.peek_next()
141-
.map(|(_, v)| v)
142-
.cloned()
143-
}
144-
RequestTime::AtSlot(slot) => {
145-
// Get the state with slot equal to the lookup slot.
146-
key_cache
147-
.iter()
148-
.rev() // Usually the slot lies at the end of the map
149-
.find(|(k, _)| k.slot == slot)
150-
.map(|(_, v)| v)
151-
.cloned()
152-
}
153-
}
154-
}
155-
None => None,
156-
}
109+
/// A collection of caches for various program state.
110+
pub struct CacheState {
111+
accumulator_messages_cache: AccumulatorMessagesCache,
112+
wormhole_merkle_state_cache: WormholeMerkleStateCache,
113+
message_cache: MessageCache,
114+
cache_size: u64,
157115
}
158116

159-
impl Cache {
160-
pub fn new(cache_size: u64) -> Self {
117+
impl CacheState {
118+
pub fn new(size: u64) -> Self {
161119
Self {
162-
message_cache: Arc::new(RwLock::new(HashMap::new())),
163-
accumulator_messages_cache: Arc::new(RwLock::new(BTreeMap::new())),
120+
accumulator_messages_cache: Arc::new(RwLock::new(BTreeMap::new())),
164121
wormhole_merkle_state_cache: Arc::new(RwLock::new(BTreeMap::new())),
165-
cache_size,
122+
message_cache: Arc::new(RwLock::new(HashMap::new())),
123+
cache_size: size,
166124
}
167125
}
168126
}
169127

170-
#[async_trait::async_trait]
171-
pub trait AggregateCache {
128+
/// Allow downcasting State into CacheState for functions that depend on the `Cache` service.
129+
impl<'a> From<&'a State> for &'a CacheState {
130+
fn from(state: &'a State) -> &'a CacheState {
131+
&state.cache
132+
}
133+
}
134+
135+
pub trait Cache {
172136
async fn message_state_keys(&self) -> Vec<MessageStateKey>;
173137
async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()>;
174138
async fn prune_removed_keys(&self, current_keys: HashSet<MessageStateKey>);
@@ -190,10 +154,13 @@ pub trait AggregateCache {
190154
async fn fetch_wormhole_merkle_state(&self, slot: Slot) -> Result<Option<WormholeMerkleState>>;
191155
}
192156

193-
#[async_trait::async_trait]
194-
impl AggregateCache for crate::state::State {
157+
impl<T> Cache for T
158+
where
159+
for<'a> &'a T: Into<&'a CacheState>,
160+
T: Sync,
161+
{
195162
async fn message_state_keys(&self) -> Vec<MessageStateKey> {
196-
self.cache
163+
self.into()
197164
.message_cache
198165
.read()
199166
.await
@@ -203,7 +170,7 @@ impl AggregateCache for crate::state::State {
203170
}
204171

205172
async fn store_message_states(&self, message_states: Vec<MessageState>) -> Result<()> {
206-
let mut message_cache = self.cache.message_cache.write().await;
173+
let mut message_cache = self.into().message_cache.write().await;
207174

208175
for message_state in message_states {
209176
let key = message_state.key();
@@ -212,7 +179,7 @@ impl AggregateCache for crate::state::State {
212179
cache.insert(time, message_state);
213180

214181
// Remove the earliest message states if the cache size is exceeded
215-
while cache.len() > self.cache.cache_size as usize {
182+
while cache.len() > self.into().cache_size as usize {
216183
cache.pop_first();
217184
}
218185
}
@@ -227,7 +194,7 @@ impl AggregateCache for crate::state::State {
227194
/// lose the cache for that key and cannot retrieve it for historical
228195
/// price queries.
229196
async fn prune_removed_keys(&self, current_keys: HashSet<MessageStateKey>) {
230-
let mut message_cache = self.cache.message_cache.write().await;
197+
let mut message_cache = self.into().message_cache.write().await;
231198

232199
// Sometimes, some keys are removed from the accumulator. We track which keys are not
233200
// present in the message states and remove them from the cache.
@@ -262,7 +229,7 @@ impl AggregateCache for crate::state::State {
262229
feed_id: id,
263230
type_: message_type,
264231
};
265-
retrieve_message_state(&self.cache, key, request_time.clone())
232+
retrieve_message_state(self.into(), key, request_time.clone())
266233
})
267234
}))
268235
.await
@@ -275,60 +242,95 @@ impl AggregateCache for crate::state::State {
275242
&self,
276243
accumulator_messages: AccumulatorMessages,
277244
) -> Result<()> {
278-
let mut cache = self.cache.accumulator_messages_cache.write().await;
245+
let mut cache = self.into().accumulator_messages_cache.write().await;
279246
cache.insert(accumulator_messages.slot, accumulator_messages);
280-
while cache.len() > self.cache.cache_size as usize {
247+
while cache.len() > self.into().cache_size as usize {
281248
cache.pop_first();
282249
}
283250
Ok(())
284251
}
285252

286253
async fn fetch_accumulator_messages(&self, slot: Slot) -> Result<Option<AccumulatorMessages>> {
287-
let cache = self.cache.accumulator_messages_cache.read().await;
254+
let cache = self.into().accumulator_messages_cache.read().await;
288255
Ok(cache.get(&slot).cloned())
289256
}
290257

291258
async fn store_wormhole_merkle_state(
292259
&self,
293260
wormhole_merkle_state: WormholeMerkleState,
294261
) -> Result<()> {
295-
let mut cache = self.cache.wormhole_merkle_state_cache.write().await;
262+
let mut cache = self.into().wormhole_merkle_state_cache.write().await;
296263
cache.insert(wormhole_merkle_state.root.slot, wormhole_merkle_state);
297-
while cache.len() > self.cache.cache_size as usize {
264+
while cache.len() > self.into().cache_size as usize {
298265
cache.pop_first();
299266
}
300267
Ok(())
301268
}
302269

303270
async fn fetch_wormhole_merkle_state(&self, slot: Slot) -> Result<Option<WormholeMerkleState>> {
304-
let cache = self.cache.wormhole_merkle_state_cache.read().await;
271+
let cache = self.into().wormhole_merkle_state_cache.read().await;
305272
Ok(cache.get(&slot).cloned())
306273
}
307274
}
308275

276+
async fn retrieve_message_state(
277+
cache: &CacheState,
278+
key: MessageStateKey,
279+
request_time: RequestTime,
280+
) -> Option<MessageState> {
281+
match cache.message_cache.read().await.get(&key) {
282+
Some(key_cache) => {
283+
match request_time {
284+
RequestTime::Latest => key_cache.last_key_value().map(|(_, v)| v).cloned(),
285+
RequestTime::FirstAfter(time) => {
286+
// If the requested time is before the first element in the vector, we are
287+
// not sure that the first element is the closest one.
288+
if let Some((_, oldest_record_value)) = key_cache.first_key_value() {
289+
if time < oldest_record_value.time().publish_time {
290+
return None;
291+
}
292+
}
293+
294+
let lookup_time = MessageStateTime {
295+
publish_time: time,
296+
slot: 0,
297+
};
298+
299+
// Get the first element that is greater than or equal to the lookup time.
300+
key_cache
301+
.lower_bound(Bound::Included(&lookup_time))
302+
.peek_next()
303+
.map(|(_, v)| v)
304+
.cloned()
305+
}
306+
RequestTime::AtSlot(slot) => {
307+
// Get the state with slot equal to the lookup slot.
308+
key_cache
309+
.iter()
310+
.rev() // Usually the slot lies at the end of the map
311+
.find(|(k, _)| k.slot == slot)
312+
.map(|(_, v)| v)
313+
.cloned()
314+
}
315+
}
316+
}
317+
None => None,
318+
}
319+
}
320+
309321
#[cfg(test)]
310322
mod test {
311323
use {
312324
super::*,
313325
crate::{
314-
aggregate::{
315-
wormhole_merkle::{
316-
WormholeMerkleMessageProof,
317-
WormholeMerkleState,
318-
},
319-
AccumulatorMessages,
320-
ProofSet,
321-
},
326+
aggregate::wormhole_merkle::WormholeMerkleMessageProof,
322327
state::test::setup_state,
323328
},
324329
pyth_sdk::UnixTimestamp,
325330
pythnet_sdk::{
326331
accumulators::merkle::MerklePath,
327332
hashers::keccak256_160::Keccak160,
328-
messages::{
329-
Message,
330-
PriceFeedMessage,
331-
},
333+
messages::PriceFeedMessage,
332334
wire::v1::WormholeMerkleRoot,
333335
},
334336
};
@@ -369,7 +371,7 @@ mod test {
369371
slot: Slot,
370372
) -> MessageState
371373
where
372-
S: AggregateCache,
374+
S: Cache,
373375
{
374376
let message_state = create_dummy_price_feed_message_state(feed_id, publish_time, slot);
375377
state

0 commit comments

Comments (0)