@@ -9,9 +9,11 @@ use std::{
99 cmp:: min,
1010 collections:: HashMap ,
1111 sync:: { Arc , Mutex } ,
12+ time:: Duration ,
1213} ;
1314use thegraph_core:: DeploymentId ;
14- use tokio:: { select, sync:: mpsc:: Receiver , task:: JoinHandle } ;
15+ use tokio:: { sync:: mpsc:: Receiver , task:: JoinHandle } ;
16+ use ttl_cache:: TtlCache ;
1517
1618use tap_core:: {
1719 receipt:: {
@@ -23,67 +25,77 @@ use tap_core::{
2325} ;
2426
2527pub struct MinimumValue {
26- cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModel > > > ,
28+ cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
2729 query_ids : Arc < Mutex < HashMap < SignatureBytes , AgoraQuery > > > ,
28- handle : JoinHandle < ( ) > ,
30+ model_handle : JoinHandle < ( ) > ,
31+ query_handle : JoinHandle < ( ) > ,
2932}
3033
3134impl MinimumValue {
3235 pub fn new (
3336 mut rx_cost_model : Receiver < CostModelSource > ,
3437 mut rx_query : Receiver < AgoraQuery > ,
3538 ) -> Self {
36- let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
39+ let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: < DeploymentId , CostModelCache > :: new ( ) ) ) ;
3740 let query_ids = Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ;
3841 let cache = cost_model_cache. clone ( ) ;
3942 let query_ids_clone = query_ids. clone ( ) ;
40- let handle = tokio:: spawn ( async move {
43+ let model_handle = tokio:: spawn ( async move {
4144 loop {
42- select ! {
43- model = rx_cost_model . recv ( ) => {
44- match model {
45- Some ( value ) => {
46- let deployment_id = value . deployment_id ;
47-
48- match compile_cost_model ( value) {
49- Ok ( value ) => {
50- // todo keep track of the last X models
51- cache . lock ( ) . unwrap ( ) . insert ( deployment_id , value) ;
52- }
53- Err ( err ) => {
54- tracing :: error! (
55- "Error while compiling cost model for deployment id {}. Error: {}" ,
56- deployment_id , err
57- )
58- }
45+ let model = rx_cost_model . recv ( ) . await ;
46+ match model {
47+ Some ( value ) => {
48+ let deployment_id = value . deployment_id ;
49+
50+ if let Some ( query ) = cache . lock ( ) . unwrap ( ) . get_mut ( & deployment_id ) {
51+ let _ = query . insert_model ( value) ;
52+ } else {
53+ match CostModelCache :: new ( value ) {
54+ Ok ( value) => {
55+ cache . lock ( ) . unwrap ( ) . insert ( deployment_id , value ) ;
56+ }
57+ Err ( err ) => {
58+ tracing :: error! (
59+ "Error while compiling cost model for deployment id {}. Error: {}" ,
60+ deployment_id , err
61+ )
5962 }
6063 }
61- None => continue ,
6264 }
6365 }
64- query = rx_query. recv( ) => {
65- match query {
66- Some ( query) => {
67- query_ids_clone. lock( ) . unwrap( ) . insert( query. signature. get_signature_bytes( ) , query) ;
68- } ,
69- None => continue ,
70- }
66+ None => continue ,
67+ }
68+ }
69+ } ) ;
70+
71+ let query_handle = tokio:: spawn ( async move {
72+ loop {
73+ let query = rx_query. recv ( ) . await ;
74+ match query {
75+ Some ( query) => {
76+ query_ids_clone
77+ . lock ( )
78+ . unwrap ( )
79+ . insert ( query. signature . get_signature_bytes ( ) , query) ;
7180 }
81+ None => continue ,
7282 }
7383 }
7484 } ) ;
7585
7686 Self {
7787 cost_model_cache,
78- handle ,
88+ model_handle ,
7989 query_ids,
90+ query_handle,
8091 }
8192 }
8293}
8394
8495impl Drop for MinimumValue {
8596 fn drop ( & mut self ) {
86- self . handle . abort ( ) ;
97+ self . model_handle . abort ( ) ;
98+ self . query_handle . abort ( ) ;
8799 }
88100}
89101
@@ -103,32 +115,16 @@ impl Check for MinimumValue {
103115 . map_err ( CheckError :: Failed ) ?;
104116
105117 // get agora model for the allocation_id
106- let cache = self . cost_model_cache . lock ( ) . unwrap ( ) ;
118+ let mut cache = self . cost_model_cache . lock ( ) . unwrap ( ) ;
107119
108120 // on average, we'll have zero or one model
109- let models = cache
110- . get ( & agora_query. deployment_id )
111- . map ( |model| vec ! [ model] )
112- . unwrap_or_default ( ) ;
121+ let models = cache. get_mut ( & agora_query. deployment_id ) ;
113122
114123 // get value
115124 let value = receipt. signed_receipt ( ) . message . value ;
116125
117126 let expected_value = models
118- . into_iter ( )
119- . fold ( None , |acc, model| {
120- let value = model
121- . cost ( & agora_query. query , & agora_query. variables )
122- . ok ( )
123- . map ( |fee| fee. to_u128 ( ) . unwrap_or_default ( ) )
124- . unwrap_or_default ( ) ;
125- if let Some ( acc) = acc {
126- // return the minimum value of the cache list
127- Some ( min ( acc, value) )
128- } else {
129- Some ( value)
130- }
131- } )
127+ . map ( |cache| cache. cost ( & agora_query) )
132128 . unwrap_or_default ( ) ;
133129
134130 let should_accept = value >= expected_value;
@@ -151,11 +147,11 @@ impl Check for MinimumValue {
151147 }
152148}
153149
154- fn compile_cost_model ( src : CostModelSource ) -> Result < CostModel , String > {
150+ fn compile_cost_model ( src : CostModelSource ) -> anyhow :: Result < CostModel > {
155151 if src. model . len ( ) > ( 1 << 16 ) {
156- return Err ( "CostModelTooLarge" . into ( ) ) ;
152+ return Err ( anyhow ! ( "CostModelTooLarge" ) ) ;
157153 }
158- let model = CostModel :: compile ( & src. model , & src. variables ) . map_err ( |err| err . to_string ( ) ) ?;
154+ let model = CostModel :: compile ( & src. model , & src. variables ) ?;
159155 Ok ( model)
160156}
161157
@@ -166,9 +162,68 @@ pub struct AgoraQuery {
166162 variables : String ,
167163}
168164
169- #[ derive( Eq , Hash , PartialEq ) ]
165+ #[ derive( Clone , Eq , Hash , PartialEq ) ]
170166pub struct CostModelSource {
171167 deployment_id : DeploymentId ,
172168 model : String ,
173169 variables : String ,
174170}
171+
172+ pub struct CostModelCache {
173+ models : TtlCache < CostModelSource , CostModel > ,
174+ latest_model : CostModel ,
175+ latest_source : CostModelSource ,
176+ }
177+
178+ impl CostModelCache {
179+ pub fn new ( source : CostModelSource ) -> anyhow:: Result < Self > {
180+ let model = compile_cost_model ( source. clone ( ) ) ?;
181+ Ok ( Self {
182+ latest_model : model,
183+ latest_source : source,
184+ // arbitrary number of models copy
185+ models : TtlCache :: new ( 10 ) ,
186+ } )
187+ }
188+
189+ fn insert_model ( & mut self , source : CostModelSource ) -> anyhow:: Result < ( ) > {
190+ if source != self . latest_source {
191+ let model = compile_cost_model ( source. clone ( ) ) ?;
192+ // update latest and insert into ttl the old model
193+ let old_model = std:: mem:: replace ( & mut self . latest_model , model) ;
194+ self . latest_source = source. clone ( ) ;
195+
196+ self . models
197+ // arbitrary cache duration
198+ . insert ( source, old_model, Duration :: from_secs ( 60 ) ) ;
199+ }
200+ Ok ( ( ) )
201+ }
202+
203+ fn get_models ( & mut self ) -> Vec < & CostModel > {
204+ let mut values: Vec < & CostModel > = self . models . iter ( ) . map ( |( _, v) | v) . collect ( ) ;
205+ values. push ( & self . latest_model ) ;
206+ values
207+ }
208+
209+ fn cost ( & mut self , query : & AgoraQuery ) -> u128 {
210+ let models = self . get_models ( ) ;
211+
212+ models
213+ . into_iter ( )
214+ . fold ( None , |acc, model| {
215+ let value = model
216+ . cost ( & query. query , & query. variables )
217+ . ok ( )
218+ . map ( |fee| fee. to_u128 ( ) . unwrap_or_default ( ) )
219+ . unwrap_or_default ( ) ;
220+ if let Some ( acc) = acc {
221+ // return the minimum value of the cache list
222+ Some ( min ( acc, value) )
223+ } else {
224+ Some ( value)
225+ }
226+ } )
227+ . unwrap_or_default ( )
228+ }
229+ }
0 commit comments