11use core:: hash:: Hash ;
22use std:: {
3+ collections:: HashMap ,
34 fmt:: Debug ,
45 sync:: Arc ,
56} ;
@@ -68,7 +69,8 @@ pub struct AsyncLru<RT: Runtime, Key, Value> {
6869 pause_client : Option < Arc < tokio:: sync:: Mutex < PauseClient > > > ,
6970}
7071
71- pub type ValueGenerator < Value > = BoxFuture < ' static , anyhow:: Result < Value > > ;
72+ pub type SingleValueGenerator < Value > = BoxFuture < ' static , anyhow:: Result < Value > > ;
73+ pub type ValueGenerator < Key , Value > = BoxFuture < ' static , HashMap < Key , anyhow:: Result < Value > > > ;
7274
7375impl < RT : Runtime , Key , Value > Clone for AsyncLru < RT , Key , Value > {
7476 fn clone ( & self ) -> Self {
@@ -142,7 +144,7 @@ type BuildValueResult<Value> = Result<Arc<Value>, Arc<anyhow::Error>>;
142144
143145type BuildValueRequest < Key , Value > = (
144146 Key ,
145- ValueGenerator < Value > ,
147+ ValueGenerator < Key , Value > ,
146148 async_broadcast:: Sender < BuildValueResult < Value > > ,
147149) ;
148150
@@ -247,8 +249,8 @@ impl<
247249 inner. current_size += new_value. size ( ) ;
248250 // Ideally we'd not change the LRU order by putting here...
249251 if let Some ( old_value) = inner. cache . put ( key, new_value) {
250- anyhow :: ensure! ( !matches! ( old_value , CacheResult :: Ready { .. } ) ) ;
251- // Just in case we ever assign a size to our Waiting entries .
252+ // Allow overwriting entries (Waiting or Ready) which may have been populated
253+ // by racing requests with prefetches .
252254 inner. current_size -= old_value. size ( ) ;
253255 }
254256 Self :: trim_to_size ( & mut inner) ;
@@ -300,21 +302,45 @@ impl<
300302 inner. current_size
301303 }
302304
303- pub async fn get (
305+ pub async fn get_and_prepopulate (
304306 & self ,
305307 key : Key ,
306- value_generator : ValueGenerator < Value > ,
308+ value_generator : ValueGenerator < Key , Value > ,
307309 ) -> anyhow:: Result < Arc < Value > > {
308310 let timer = async_lru_get_timer ( self . label ) ;
309311 let result = self . _get ( & key, value_generator) . await ;
310312 timer. finish ( result. is_ok ( ) ) ;
311313 result
312314 }
313315
316+ pub async fn get (
317+ & self ,
318+ key : Key ,
319+ value_generator : SingleValueGenerator < Value > ,
320+ ) -> anyhow:: Result < Arc < Value > >
321+ where
322+ Key : Clone ,
323+ {
324+ let timer = async_lru_get_timer ( self . label ) ;
325+ let key_ = key. clone ( ) ;
326+ let result = self
327+ . _get (
328+ & key_,
329+ Box :: pin ( async move {
330+ let mut hashmap = HashMap :: new ( ) ;
331+ hashmap. insert ( key, value_generator. await ) ;
332+ hashmap
333+ } ) ,
334+ )
335+ . await ;
336+ timer. finish ( result. is_ok ( ) ) ;
337+ result
338+ }
339+
314340 async fn _get (
315341 & self ,
316342 key : & Key ,
317- value_generator : ValueGenerator < Value > ,
343+ value_generator : ValueGenerator < Key , Value > ,
318344 ) -> anyhow:: Result < Arc < Value > > {
319345 match self . get_sync ( key, value_generator) ? {
320346 Status :: Ready ( value) => Ok ( value) ,
@@ -336,7 +362,7 @@ impl<
336362 fn get_sync (
337363 & self ,
338364 key : & Key ,
339- value_generator : ValueGenerator < Value > ,
365+ value_generator : ValueGenerator < Key , Value > ,
340366 ) -> anyhow:: Result < Status < Value > > {
341367 let mut inner = self . inner . lock ( ) ;
342368 log_async_lru_size ( inner. cache . len ( ) , inner. current_size , self . label ) ;
@@ -407,10 +433,16 @@ impl<
407433 return ;
408434 }
409435
410- let value = generator. await ;
436+ let values = generator. await ;
411437
412- let to_broadcast = Self :: update_value ( rt, inner, key, value) . map_err ( Arc :: new) ;
413- let _ = tx. broadcast ( to_broadcast) . await ;
438+ for ( k, value) in values {
439+ let is_requested_key = k == key;
440+ let to_broadcast =
441+ Self :: update_value ( rt. clone ( ) , inner. clone ( ) , k, value) . map_err ( Arc :: new) ;
442+ if is_requested_key {
443+ let _ = tx. broadcast ( to_broadcast) . await ;
444+ }
445+ }
414446 }
415447 } )
416448 . await ;
@@ -420,7 +452,10 @@ impl<
420452
421453#[ cfg( test) ]
422454mod tests {
423- use std:: sync:: Arc ;
455+ use std:: {
456+ collections:: HashMap ,
457+ sync:: Arc ,
458+ } ;
424459
425460 use common:: {
426461 pause:: PauseController ,
@@ -536,6 +571,36 @@ mod tests {
536571 Ok ( ( ) )
537572 }
538573
574+ #[ convex_macro:: test_runtime]
575+ async fn test_get_and_prepopulate ( rt : TestRuntime ) -> anyhow:: Result < ( ) > {
576+ let cache = AsyncLru :: new ( rt, 10 , 1 , "label" ) ;
577+ let first = cache
578+ . get_and_prepopulate (
579+ "k1" ,
580+ async move {
581+ let mut hashmap = HashMap :: new ( ) ;
582+ hashmap. insert ( "k1" , Ok ( 1 ) ) ;
583+ hashmap. insert ( "k2" , Ok ( 2 ) ) ;
584+ hashmap. insert ( "k3" , Err ( anyhow:: anyhow!( "k3 failed" ) ) ) ;
585+ hashmap
586+ }
587+ . boxed ( ) ,
588+ )
589+ . await ?;
590+ assert_eq ! ( * first, 1 ) ;
591+ let k1_again = cache
592+ . get ( "k1" , GenerateRandomValue :: generate_value ( "k1" ) . boxed ( ) )
593+ . await ?;
594+ assert_eq ! ( * k1_again, 1 ) ;
595+ let k2_prepopulated = cache
596+ . get ( "k2" , GenerateRandomValue :: generate_value ( "k2" ) . boxed ( ) )
597+ . await ?;
598+ assert_eq ! ( * k2_prepopulated, 2 ) ;
599+ let k3_prepopulated = cache. get ( "k3" , async move { Ok ( 3 ) } . boxed ( ) ) . await ?;
600+ assert_eq ! ( * k3_prepopulated, 3 ) ;
601+ Ok ( ( ) )
602+ }
603+
539604 #[ convex_macro:: test_runtime]
540605 async fn get_generates_new_value_after_eviction ( rt : TestRuntime ) -> anyhow:: Result < ( ) > {
541606 let cache = AsyncLru :: new ( rt, 1 , 1 , "label" ) ;
0 commit comments