@@ -7,9 +7,11 @@ use std::{
77 cmp:: min,
88 collections:: HashMap ,
99 sync:: { Arc , Mutex } ,
10+ time:: Duration ,
1011} ;
1112use thegraph:: types:: DeploymentId ;
12- use tokio:: { select, sync:: mpsc:: Receiver , task:: JoinHandle } ;
13+ use tokio:: { sync:: mpsc:: Receiver , task:: JoinHandle } ;
14+ use ttl_cache:: TtlCache ;
1315
1416use anyhow:: anyhow;
1517use cost_model:: CostModel ;
@@ -20,67 +22,77 @@ use tap_core::receipt::{
2022} ;
2123
2224pub struct MinimumValue {
23- cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModel > > > ,
25+ cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
2426 query_ids : Arc < Mutex < HashMap < Signature , AgoraQuery > > > ,
25- handle : JoinHandle < ( ) > ,
27+ model_handle : JoinHandle < ( ) > ,
28+ query_handle : JoinHandle < ( ) > ,
2629}
2730
2831impl MinimumValue {
2932 pub fn new (
3033 mut rx_cost_model : Receiver < CostModelSource > ,
3134 mut rx_query : Receiver < AgoraQuery > ,
3235 ) -> Self {
33- let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
36+ let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: < DeploymentId , CostModelCache > :: new ( ) ) ) ;
3437 let query_ids = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
3538 let cache = cost_model_cache. clone ( ) ;
3639 let query_ids_clone = query_ids. clone ( ) ;
37- let handle = tokio:: spawn ( async move {
40+ let model_handle = tokio:: spawn ( async move {
3841 loop {
39- select ! {
40- model = rx_cost_model . recv ( ) => {
41- match model {
42- Some ( value ) => {
43- let deployment_id = value . deployment_id ;
44-
45- match compile_cost_model ( value) {
46- Ok ( value ) => {
47- // todo keep track of the last X models
48- cache . lock ( ) . unwrap ( ) . insert ( deployment_id , value) ;
49- }
50- Err ( err ) => {
51- tracing :: error! (
52- "Error while compiling cost model for deployment id {}. Error: {}" ,
53- deployment_id , err
54- )
55- }
42+ let model = rx_cost_model . recv ( ) . await ;
43+ match model {
44+ Some ( value ) => {
45+ let deployment_id = value . deployment_id ;
46+
47+ if let Some ( query ) = cache . lock ( ) . unwrap ( ) . get_mut ( & deployment_id ) {
48+ let _ = query . insert_model ( value) ;
49+ } else {
50+ match CostModelCache :: new ( value ) {
51+ Ok ( value) => {
52+ cache . lock ( ) . unwrap ( ) . insert ( deployment_id , value ) ;
53+ }
54+ Err ( err ) => {
55+ tracing :: error! (
56+ "Error while compiling cost model for deployment id {}. Error: {}" ,
57+ deployment_id , err
58+ )
5659 }
5760 }
58- None => continue ,
5961 }
6062 }
61- query = rx_query. recv( ) => {
62- match query {
63- Some ( query) => {
64- query_ids_clone. lock( ) . unwrap( ) . insert( query. signature, query) ;
65- } ,
66- None => continue ,
67- }
63+ None => continue ,
64+ }
65+ }
66+ } ) ;
67+
68+ let query_handle = tokio:: spawn ( async move {
69+ loop {
70+ let query = rx_query. recv ( ) . await ;
71+ match query {
72+ Some ( query) => {
73+ query_ids_clone
74+ . lock ( )
75+ . unwrap ( )
76+ . insert ( query. signature , query) ;
6877 }
78+ None => continue ,
6979 }
7080 }
7181 } ) ;
7282
7383 Self {
7484 cost_model_cache,
75- handle ,
85+ model_handle ,
7686 query_ids,
87+ query_handle,
7788 }
7889 }
7990}
8091
8192impl Drop for MinimumValue {
8293 fn drop ( & mut self ) {
83- self . handle . abort ( ) ;
94+ self . model_handle . abort ( ) ;
95+ self . query_handle . abort ( ) ;
8496 }
8597}
8698
@@ -99,32 +111,16 @@ impl Check for MinimumValue {
99111 . ok_or ( anyhow ! ( "No query found" ) ) ?;
100112
101113 // get agora model for the allocation_id
102- let cache = self . cost_model_cache . lock ( ) . unwrap ( ) ;
114+ let mut cache = self . cost_model_cache . lock ( ) . unwrap ( ) ;
103115
104116 // on average, we'll have zero or one model
105- let models = cache
106- . get ( & agora_query. deployment_id )
107- . map ( |model| vec ! [ model] )
108- . unwrap_or_default ( ) ;
117+ let models = cache. get_mut ( & agora_query. deployment_id ) ;
109118
110119 // get value
111120 let value = receipt. signed_receipt ( ) . message . value ;
112121
113122 let expected_value = models
114- . into_iter ( )
115- . fold ( None , |acc, model| {
116- let value = model
117- . cost ( & agora_query. query , & agora_query. variables )
118- . ok ( )
119- . map ( |fee| fee. to_u128 ( ) . unwrap_or_default ( ) )
120- . unwrap_or_default ( ) ;
121- if let Some ( acc) = acc {
122- // return the minimum value of the cache list
123- Some ( min ( acc, value) )
124- } else {
125- Some ( value)
126- }
127- } )
123+ . map ( |cache| cache. cost ( & agora_query) )
128124 . unwrap_or_default ( ) ;
129125
130126 let should_accept = value >= expected_value;
@@ -147,11 +143,11 @@ impl Check for MinimumValue {
147143 }
148144}
149145
150- fn compile_cost_model ( src : CostModelSource ) -> Result < CostModel , String > {
146+ fn compile_cost_model ( src : CostModelSource ) -> anyhow :: Result < CostModel > {
151147 if src. model . len ( ) > ( 1 << 16 ) {
152- return Err ( "CostModelTooLarge" . into ( ) ) ;
148+ return Err ( anyhow ! ( "CostModelTooLarge" ) ) ;
153149 }
154- let model = CostModel :: compile ( & src. model , & src. variables ) . map_err ( |err| err . to_string ( ) ) ?;
150+ let model = CostModel :: compile ( & src. model , & src. variables ) ?;
155151 Ok ( model)
156152}
157153
@@ -162,9 +158,68 @@ pub struct AgoraQuery {
162158 variables : String ,
163159}
164160
165- #[ derive( Eq , Hash , PartialEq ) ]
161+ #[ derive( Clone , Eq , Hash , PartialEq ) ]
166162pub struct CostModelSource {
167163 deployment_id : DeploymentId ,
168164 model : String ,
169165 variables : String ,
170166}
167+
168+ pub struct CostModelCache {
169+ models : TtlCache < CostModelSource , CostModel > ,
170+ latest_model : CostModel ,
171+ latest_source : CostModelSource ,
172+ }
173+
174+ impl CostModelCache {
175+ pub fn new ( source : CostModelSource ) -> anyhow:: Result < Self > {
176+ let model = compile_cost_model ( source. clone ( ) ) ?;
177+ Ok ( Self {
178+ latest_model : model,
179+ latest_source : source,
180+ // arbitrary number of models copy
181+ models : TtlCache :: new ( 10 ) ,
182+ } )
183+ }
184+
185+ fn insert_model ( & mut self , source : CostModelSource ) -> anyhow:: Result < ( ) > {
186+ if source != self . latest_source {
187+ let model = compile_cost_model ( source. clone ( ) ) ?;
188+ // update latest and insert into ttl the old model
189+ let old_model = std:: mem:: replace ( & mut self . latest_model , model) ;
190+ self . latest_source = source. clone ( ) ;
191+
192+ self . models
193+ // arbitrary cache duration
194+ . insert ( source, old_model, Duration :: from_secs ( 60 ) ) ;
195+ }
196+ Ok ( ( ) )
197+ }
198+
199+ fn get_models ( & mut self ) -> Vec < & CostModel > {
200+ let mut values: Vec < & CostModel > = self . models . iter ( ) . map ( |( _, v) | v) . collect ( ) ;
201+ values. push ( & self . latest_model ) ;
202+ values
203+ }
204+
205+ fn cost ( & mut self , query : & AgoraQuery ) -> u128 {
206+ let models = self . get_models ( ) ;
207+
208+ models
209+ . into_iter ( )
210+ . fold ( None , |acc, model| {
211+ let value = model
212+ . cost ( & query. query , & query. variables )
213+ . ok ( )
214+ . map ( |fee| fee. to_u128 ( ) . unwrap_or_default ( ) )
215+ . unwrap_or_default ( ) ;
216+ if let Some ( acc) = acc {
217+ // return the minimum value of the cache list
218+ Some ( min ( acc, value) )
219+ } else {
220+ Some ( value)
221+ }
222+ } )
223+ . unwrap_or_default ( )
224+ }
225+ }
0 commit comments