Skip to content

Commit 8d954bf

Browse files
committed
refactor: split in multiple functions
1 parent 22ff07b commit 8d954bf

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

common/src/tap/checks/value_check.rs

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ impl MinimumValue {
6868
}
6969
});
7070

71+
// we use two different handles because in case one channel breaks we still have the other
7172
let query_handle = tokio::spawn(async move {
7273
loop {
7374
let query = rx_query.recv().await;
@@ -99,34 +100,41 @@ impl Drop for MinimumValue {
99100
}
100101
}
101102

102-
#[async_trait::async_trait]
103-
impl Check for MinimumValue {
104-
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
105-
// get key
106-
let key = &receipt.signed_receipt().signature.get_signature_bytes();
103+
impl MinimumValue {
104+
fn get_agora_query(&self, query_id: &SignatureBytes) -> Option<AgoraQuery> {
105+
self.query_ids.lock().unwrap().remove(query_id)
106+
}
107107

108+
fn get_expected_value(&self, query_id: &SignatureBytes) -> anyhow::Result<u128> {
108109
// get query from key
109110
let agora_query = self
110-
.query_ids
111-
.lock()
112-
.unwrap()
113-
.remove(key)
114-
.ok_or(anyhow!("No query found"))
115-
.map_err(CheckError::Failed)?;
111+
.get_agora_query(query_id)
112+
.ok_or(anyhow!("No query found"))?;
116113

117114
// get agora model for the allocation_id
118115
let mut cache = self.cost_model_cache.lock().unwrap();
119-
120116
// on average, we'll have zero or one model
121117
let models = cache.get_mut(&agora_query.deployment_id);
122118

123-
// get value
124-
let value = receipt.signed_receipt().message.value;
125-
126119
let expected_value = models
127120
.map(|cache| cache.cost(&agora_query))
128121
.unwrap_or_default();
129122

123+
Ok(expected_value)
124+
}
125+
}
126+
127+
#[async_trait::async_trait]
128+
impl Check for MinimumValue {
129+
async fn check(&self, receipt: &ReceiptWithState<Checking>) -> CheckResult {
130+
// get key
131+
let key = &receipt.signed_receipt().signature.get_signature_bytes();
132+
133+
let expected_value = self.get_expected_value(key).map_err(CheckError::Failed)?;
134+
135+
// get value
136+
let value = receipt.signed_receipt().message.value;
137+
130138
let should_accept = value >= expected_value;
131139

132140
tracing::trace!(

0 commit comments

Comments
 (0)