@@ -5,18 +5,15 @@ use anyhow::anyhow;
55use bigdecimal:: ToPrimitive ;
66use cost_model:: CostModel ;
77use sqlx:: { postgres:: PgListener , PgPool } ;
8- use tracing:: error;
98use std:: {
109 cmp:: min,
1110 collections:: HashMap ,
1211 sync:: { Arc , Mutex } ,
1312 time:: Duration ,
1413} ;
1514use thegraph_core:: DeploymentId ;
16- use tokio:: {
17- sync:: mpsc:: { Receiver , Sender } ,
18- task:: JoinHandle ,
19- } ;
15+ use tokio:: task:: JoinHandle ;
16+ use tracing:: error;
2017use ttl_cache:: TtlCache ;
2118
2219use tap_core:: receipt:: {
@@ -30,60 +27,25 @@ pub struct MinimumValue {
3027 model_handle : JoinHandle < ( ) > ,
3128}
3229
33- #[ derive( Clone ) ]
34- pub struct ValueCheckSender {
35- pub tx_cost_model : Sender < CostModelSource > ,
36- }
37-
38- pub struct ValueCheckReceiver {
39- rx_cost_model : Receiver < CostModelSource > ,
40- }
30+ impl MinimumValue {
31+ pub async fn new ( pgpool : PgPool ) -> Self {
32+ let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: < DeploymentId , CostModelCache > :: new ( ) ) ) ;
4133
42- pub fn create_value_check ( size : usize ) -> ( ValueCheckSender , ValueCheckReceiver ) {
43- let ( tx_cost_model, rx_cost_model) = tokio:: sync:: mpsc:: channel ( size) ;
34+ let mut pglistener = PgListener :: connect_with ( & pgpool. clone ( ) ) . await . unwrap ( ) ;
35+ pglistener. listen ( "cost_models_update_notify" ) . await . expect (
36+ "should be able to subscribe to Postgres Notify events on the channel \
37+ 'cost_models_update_notify'",
38+ ) ;
4439
45- (
46- ValueCheckSender { tx_cost_model } ,
47- ValueCheckReceiver { rx_cost_model } ,
48- )
49- }
40+ // TODO start watcher
41+ let cancel_token = tokio_util:: sync:: CancellationToken :: new ( ) ;
5042
51- impl MinimumValue {
52- pub fn new ( ValueCheckReceiver { mut rx_cost_model } : ValueCheckReceiver ) -> Self {
53- let cost_model_cache = Arc :: new ( Mutex :: new ( HashMap :: < DeploymentId , CostModelCache > :: new ( ) ) ) ;
54- let cache = cost_model_cache. clone ( ) ;
55- let model_handle = tokio:: spawn ( async move {
56- loop {
57- let model = rx_cost_model. recv ( ) . await ;
58- match model {
59- Some ( value) => {
60- let deployment_id = value. deployment_id ;
61-
62- if let Some ( query) = cache. lock ( ) . unwrap ( ) . get_mut ( & deployment_id) {
63- let _ = query. insert_model ( value) . inspect_err ( |err| {
64- tracing:: error!(
65- "Error while compiling cost model for deployment id {}. Error: {}" ,
66- deployment_id, err
67- )
68- } ) ;
69- } else {
70- match CostModelCache :: new ( value) {
71- Ok ( value) => {
72- cache. lock ( ) . unwrap ( ) . insert ( deployment_id, value) ;
73- }
74- Err ( err) => {
75- tracing:: error!(
76- "Error while compiling cost model for deployment id {}. Error: {}" ,
77- deployment_id, err
78- )
79- }
80- }
81- }
82- }
83- None => break ,
84- }
85- }
86- } ) ;
43+ let model_handle = tokio:: spawn ( Self :: cost_models_watcher (
44+ pgpool. clone ( ) ,
45+ pglistener,
46+ cost_model_cache. clone ( ) ,
47+ cancel_token. clone ( ) ,
48+ ) ) ;
8749
8850 Self {
8951 cost_model_cache,
@@ -92,17 +54,11 @@ impl MinimumValue {
9254 }
9355
9456 async fn cost_models_watcher (
95- pgpool : PgPool ,
57+ _pgpool : PgPool ,
9658 mut pglistener : PgListener ,
97- denylist : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
59+ cost_model_cache : Arc < Mutex < HashMap < DeploymentId , CostModelCache > > > ,
9860 cancel_token : tokio_util:: sync:: CancellationToken ,
9961 ) {
100- #[ derive( serde:: Deserialize ) ]
101- struct DenylistNotification {
102- tg_op : String ,
103- deployment : DeploymentId ,
104- }
105-
10662 loop {
10763 tokio:: select! {
10864 _ = cancel_token. cancelled( ) => {
@@ -112,39 +68,58 @@ impl MinimumValue {
11268 pg_notification = pglistener. recv( ) => {
11369 let pg_notification = pg_notification. expect(
11470 "should be able to receive Postgres Notify events on the channel \
115- 'scalar_tap_deny_notification '",
71+ 'cost_models_update_notify '",
11672 ) ;
11773
118- let denylist_notification : DenylistNotification =
74+ let cost_model_notification : CostModelNotification =
11975 serde_json:: from_str( pg_notification. payload( ) ) . expect(
12076 "should be able to deserialize the Postgres Notify event payload as a \
121- DenylistNotification ",
77+ CostModelNotification ",
12278 ) ;
12379
124- match denylist_notification. tg_op. as_str( ) {
80+ let deployment_id = cost_model_notification. deployment;
81+
82+ match cost_model_notification. tg_op. as_str( ) {
12583 "INSERT" => {
126- denylist
127- . write( )
128- . unwrap( )
129- . insert( denylist_notification. sender_address) ;
84+ let cost_model_source: CostModelSource = cost_model_notification. into( ) ;
85+ let mut cost_model_cache = cost_model_cache
86+ . lock( )
87+ . unwrap( ) ;
88+
89+ match cost_model_cache. get_mut( & deployment_id) {
90+ Some ( cache) => {
91+ let _ = cache. insert_model( cost_model_source) ;
92+ } ,
93+ None => {
94+ if let Ok ( cache) = CostModelCache :: new( cost_model_source) . inspect_err( |err| {
95+ tracing:: error!(
96+ "Error while compiling cost model for deployment id {}. Error: {}" ,
97+ deployment_id, err
98+ )
99+ } ) {
100+ cost_model_cache. insert( deployment_id, cache) ;
101+ }
102+ } ,
103+ }
130104 }
131105 "DELETE" => {
132- denylist
133- . write ( )
106+ cost_model_cache
107+ . lock ( )
134108 . unwrap( )
135- . remove( & denylist_notification . sender_address ) ;
109+ . remove( & cost_model_notification . deployment ) ;
136110 }
137- // UPDATE and TRUNCATE are not expected to happen. Reload the entire denylist.
111+ // UPDATE and TRUNCATE are not expected to happen. Reload the entire cost
112+ // model cache.
138113 _ => {
139114 error!(
140- "Received an unexpected denylist table notification: {}. Reloading entire \
141- denylist .",
142- denylist_notification . tg_op
115+ "Received an unexpected cost model table notification: {}. Reloading entire \
116+ cost model .",
117+ cost_model_notification . tg_op
143118 ) ;
144119
145- Self :: sender_denylist_reload( pgpool. clone( ) , denylist. clone( ) )
146- . await
147- . expect( "should be able to reload the sender denylist " )
120+ // Self::sender_denylist_reload(pgpool.clone(), denylist.clone())
121+ // .await
122+ // .expect("should be able to reload cost models ")
148123 }
149124 }
150125 }
@@ -229,6 +204,24 @@ pub struct CostModelSource {
229204 pub variables : String ,
230205}
231206
207+ #[ derive( serde:: Deserialize ) ]
208+ struct CostModelNotification {
209+ tg_op : String ,
210+ deployment : DeploymentId ,
211+ model : String ,
212+ variables : String ,
213+ }
214+
215+ impl From < CostModelNotification > for CostModelSource {
216+ fn from ( value : CostModelNotification ) -> Self {
217+ CostModelSource {
218+ deployment_id : value. deployment ,
219+ model : value. model ,
220+ variables : value. variables ,
221+ }
222+ }
223+ }
224+
232225pub struct CostModelCache {
233226 models : TtlCache < CostModelSource , CostModel > ,
234227 latest_model : CostModel ,
0 commit comments