Skip to content

Commit cbe8af9

Browse files
committed
refactor: split in multiple functions
1 parent 4cb2fba commit cbe8af9

File tree

1 file changed

+22
-13
lines changed

1 file changed

+22
-13
lines changed

common/src/tap/checks/value_check.rs

Lines changed: 22 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ impl MinimumValue {
6565
}
6666
});
6767

68+
// we use two different handles because in case one channel breaks we still have the other
6869
let query_handle = tokio::spawn(async move {
6970
loop {
7071
let query = rx_query.recv().await;
@@ -96,33 +97,41 @@ impl Drop for MinimumValue {
9697
}
9798
}
9899

99-
#[async_trait::async_trait]
100-
impl Check for MinimumValue {
101-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
102-
// get key
103-
let key = &receipt.signed_receipt().signature;
100+
impl MinimumValue {
101+
fn get_agora_query(&self, query_id: &Signature) -> Option<AgoraQuery> {
102+
self.query_ids.lock().unwrap().remove(query_id)
103+
}
104104

105+
fn get_expected_value(&self, query_id: &Signature) -> anyhow::Result<u128> {
105106
// get query from key
106107
let agora_query = self
107-
.query_ids
108-
.lock()
109-
.unwrap()
110-
.remove(key)
108+
.get_agora_query(query_id)
111109
.ok_or(anyhow!("No query found"))?;
112110

113111
// get agora model for the allocation_id
114112
let mut cache = self.cost_model_cache.lock().unwrap();
115-
116113
// on average, we'll have zero or one model
117114
let models = cache.get_mut(&agora_query.deployment_id);
118115

119-
// get value
120-
let value = receipt.signed_receipt().message.value;
121-
122116
let expected_value = models
123117
.map(|cache| cache.cost(&agora_query))
124118
.unwrap_or_default();
125119

120+
Ok(expected_value)
121+
}
122+
}
123+
124+
#[async_trait::async_trait]
125+
impl Check for MinimumValue {
126+
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
127+
// get key
128+
let key = &receipt.signed_receipt().signature;
129+
130+
let expected_value = self.get_expected_value(key)?;
131+
132+
// get value
133+
let value = receipt.signed_receipt().message.value;
134+
126135
let should_accept = value >= expected_value;
127136

128137
tracing::trace!(

0 commit comments

Comments
 (0)