 use std::{
     collections::{BTreeSet, HashMap},
     ops::Deref,
-    sync::Arc,
+    sync::{
+        atomic::{AtomicUsize, Ordering},
+        Arc,
+    },
+    time::Duration,
 };

 use async_lock::{Semaphore, SemaphoreGuard};
@@ -23,17 +27,24 @@ use scylla::{
         session::Session,
         session_builder::SessionBuilder,
     },
+    cluster::{ClusterState, Node, NodeRef},
     deserialize::{DeserializationError, TypeCheckError},
     errors::{
-        DbError, ExecutionError, IntoRowsResultError, NewSessionError, NextPageError, NextRowError,
-        PagerExecutionError, PrepareError, RequestAttemptError, RequestError, RowsError,
+        ClusterStateTokenError, DbError, ExecutionError, IntoRowsResultError, MetadataError,
+        NewSessionError, NextPageError, NextRowError, PagerExecutionError, PrepareError,
+        RequestAttemptError, RequestError, RowsError,
     },
     policies::{
-        load_balancing::{DefaultPolicy, LoadBalancingPolicy},
+        load_balancing::{DefaultPolicy, FallbackPlan, LoadBalancingPolicy, RoutingInfo},
         retry::DefaultRetryPolicy,
     },
     response::PagingState,
-    statement::{batch::BatchType, prepared::PreparedStatement, Consistency},
+    routing::{Shard, Token},
+    statement::{
+        batch::{Batch, BatchType},
+        prepared::PreparedStatement,
+        Consistency,
+    },
 };
 use serde::{Deserialize, Serialize};
 use thiserror::Error;
@@ -225,15 +236,17 @@ impl ScyllaDbClient {
     async fn build_default_session(uri: &str) -> Result<Session, ScyllaDbStoreInternalError> {
         // This explicitly sets a lot of default parameters for clarity and for making future changes
         // easier.
-        SessionBuilder::new()
+        let session = SessionBuilder::new()
             .known_node(uri)
+            .cluster_metadata_refresh_interval(Duration::from_secs(10))
             .default_execution_profile_handle(Self::build_default_execution_profile_handle(
                 Self::build_default_policy(),
             ))
             .build()
             .boxed_sync()
-            .await
-            .map_err(Into::into)
+            .await?;
+        // Refresh the metadata eagerly so that the cluster topology is known right away,
+        // e.g. for computing tokens in the sticky shard policy below.
+        session.refresh_metadata().await?;
+        Ok(session)
     }

     async fn get_multi_key_values_statement(
@@ -243,6 +256,7 @@ impl ScyllaDbClient {
         if let Some(prepared_statement) = self.multi_key_values.get(&num_markers) {
             return Ok(prepared_statement.clone());
         }
+
         let markers = std::iter::repeat_n("?", num_markers)
             .collect::<Vec<_>>()
             .join(",");
@@ -265,6 +279,7 @@ impl ScyllaDbClient {
         if let Some(prepared_statement) = self.multi_keys.get(&num_markers) {
             return Ok(prepared_statement.clone());
         };
+
         let markers = std::iter::repeat_n("?", num_markers)
             .collect::<Vec<_>>()
             .join(",");
@@ -405,46 +420,74 @@ impl ScyllaDbClient {
         Ok(rows.next().is_some())
     }

+    // Builds a sticky shard policy for this partition key, falling back to the default
+    // policy if the routing information cannot be computed from the cluster metadata.
+    fn get_sticky_shard_policy_or_default(
+        &self,
+        partition_key: &[u8],
+    ) -> Arc<dyn LoadBalancingPolicy> {
+        StickyShardPolicy::new(
+            &self.session,
+            &self.namespace,
+            partition_key,
+            ScyllaDbClient::build_default_policy(),
+        )
+        .map(|policy| Arc::new(policy) as Arc<dyn LoadBalancingPolicy>)
+        .unwrap_or_else(|_| ScyllaDbClient::build_default_policy())
+    }
+
+    // Returns a batch query with a sticky shard policy that always tries to route to the
+    // same ScyllaDB shard.
+    // Should be used only on batches where all statements target the same partition key.
+    fn get_sticky_batch_query(
+        &self,
+        partition_key: &[u8],
+    ) -> Result<Batch, ScyllaDbStoreInternalError> {
+        // Since we assume this is all for the same partition key, we can use an unlogged
+        // batch. We could use a logged batch to get atomicity across different partitions,
+        // but that comes with a huge performance penalty (it seems to double write latency).
+        let mut batch_query = Batch::new(BatchType::Unlogged);
+        let policy = self.get_sticky_shard_policy_or_default(partition_key);
+        let handle = Self::build_default_execution_profile_handle(policy);
+        batch_query.set_execution_profile_handle(Some(handle));
+
+        Ok(batch_query)
+    }
+
+    // Batches should always target the same partition key. Batches across different
+    // partitions will not be atomic; if the caller wants atomicity, it is the caller's
+    // responsibility to make sure that the batch only has statements for the same
+    // partition key.
     async fn write_batch_internal(
         &self,
         root_key: &[u8],
         batch: UnorderedBatch,
     ) -> Result<(), ScyllaDbStoreInternalError> {
-        let session = &self.session;
-        let mut batch_query = scylla::statement::batch::Batch::new(BatchType::Unlogged);
-        let mut batch_values = Vec::new();
-        let query1 = &self.write_batch_delete_prefix_unbounded;
-        let query2 = &self.write_batch_delete_prefix_bounded;
         Self::check_batch_len(&batch)?;
+        let session = &self.session;
+        let mut batch_query = self.get_sticky_batch_query(root_key)?;
+        let mut batch_values = Vec::with_capacity(batch.len());
+
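+        // For illustration (assumed semantics of `get_upper_bound_option`): a prefix such
+        // as [1, 2] has upper bound [1, 3], so deleting it becomes a bounded range delete
+        // `key >= [1, 2] AND key < [1, 3]`, while a prefix of all 0xff bytes has no upper
+        // bound and becomes an unbounded `key >= prefix` delete.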
         for key_prefix in batch.key_prefix_deletions {
             Self::check_key_size(&key_prefix)?;
             match get_upper_bound_option(&key_prefix) {
                 None => {
-                    let values = vec![root_key.to_vec(), key_prefix];
-                    batch_values.push(values);
-                    batch_query.append_statement(query1.clone());
+                    batch_query.append_statement(self.write_batch_delete_prefix_unbounded.clone());
+                    batch_values.push(vec![root_key.to_vec(), key_prefix]);
                 }
                 Some(upper_bound) => {
-                    let values = vec![root_key.to_vec(), key_prefix, upper_bound];
-                    batch_values.push(values);
-                    batch_query.append_statement(query2.clone());
+                    batch_query.append_statement(self.write_batch_delete_prefix_bounded.clone());
+                    batch_values.push(vec![root_key.to_vec(), key_prefix, upper_bound]);
                 }
             }
         }
-        let query3 = &self.write_batch_deletion;
         for key in batch.simple_unordered_batch.deletions {
             Self::check_key_size(&key)?;
-            let values = vec![root_key.to_vec(), key];
-            batch_values.push(values);
-            batch_query.append_statement(query3.clone());
+            batch_query.append_statement(self.write_batch_deletion.clone());
+            batch_values.push(vec![root_key.to_vec(), key]);
         }
-        let query4 = &self.write_batch_insertion;
         for (key, value) in batch.simple_unordered_batch.insertions {
             Self::check_key_size(&key)?;
             Self::check_value_size(&value)?;
-            let values = vec![root_key.to_vec(), key, value];
-            batch_values.push(values);
-            batch_query.append_statement(query4.clone());
+            batch_query.append_statement(self.write_batch_insertion.clone());
+            batch_values.push(vec![root_key.to_vec(), key, value]);
         }
         session.batch(&batch_query, batch_values).await?;
         Ok(())
@@ -517,6 +560,69 @@ impl ScyllaDbClient {
     }
 }

+// Batch statements in ScyllaDB are currently not token aware: the batch gets sent to a
+// random node (https://rust-driver.docs.scylladb.com/stable/statements/batch.html#performance).
+// However, for batches where all statements target the same partition key, we can use a
+// sticky shard policy to route to the owning shard, making such batches token aware.
+//
+// This policy always tries to route to the ScyllaDB shards that contain the token, in a
+// round-robin fashion.
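+//
+// Illustrative example (hypothetical cluster): if the token of the partition key is owned
+// by the replicas [(node_a, shard 2), (node_b, shard 5)], successive picks alternate
+// between (node_a, 2) and (node_b, 5), while the fallback plan still covers the rest of
+// the cluster if those attempts fail.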
+#[derive(Debug)]
+struct StickyShardPolicy {
+    // The (node, shard) pairs owning the token, computed once at construction time.
+    replicas: Vec<(Arc<Node>, Shard)>,
+    // Round-robin cursor into `replicas`.
+    current_replica_index: AtomicUsize,
+    // Policy to use when no replica information is available.
+    fallback: Arc<dyn LoadBalancingPolicy>,
+}
+
+impl StickyShardPolicy {
+    fn new(
+        session: &Session,
+        namespace: &str,
+        partition_key: &[u8],
+        fallback: Arc<dyn LoadBalancingPolicy>,
+    ) -> Result<Self, ScyllaDbStoreInternalError> {
+        let cluster = session.get_cluster_state();
+        // The partition key is passed as a one-element tuple of bind values.
+        let token = cluster.compute_token(KEYSPACE, namespace, &(partition_key,))?;
+        let replicas = cluster.get_token_endpoints(KEYSPACE, namespace, token);
+        Ok(Self {
+            replicas,
+            current_replica_index: AtomicUsize::new(0),
+            fallback,
+        })
+    }
+}
+
+impl LoadBalancingPolicy for StickyShardPolicy {
+    fn name(&self) -> String {
+        "StickyShardPolicy".to_string()
+    }
+
+    // Always try to route to the sticky shard first.
+    fn pick<'a>(
+        &'a self,
+        request: &'a RoutingInfo<'a>,
+        cluster: &'a ClusterState,
+    ) -> Option<(NodeRef<'a>, Option<Shard>)> {
+        if self.replicas.is_empty() {
+            return self.fallback.pick(request, cluster);
+        }
+        // `fetch_add` wraps around on overflow, so we are fine just incrementing forever.
+        let new_replica_index =
+            self.current_replica_index.fetch_add(1, Ordering::Relaxed) % self.replicas.len();
+        let (node, shard) = &self.replicas[new_replica_index];
+        Some((node, Some(*shard)))
+    }
+
+    // Fall back to the default policy.
+    fn fallback<'a>(
+        &'a self,
+        request: &'a RoutingInfo,
+        cluster: &'a ClusterState,
+    ) -> FallbackPlan<'a> {
+        self.fallback.fallback(request, cluster)
+    }
+}
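+//
+// Illustrative usage, mirroring `get_sticky_batch_query` above (all names are from this
+// module):
+//
+//     let policy = StickyShardPolicy::new(&session, namespace, partition_key, fallback)?;
+//     let handle = ScyllaDbClient::build_default_execution_profile_handle(Arc::new(policy));
+//     batch.set_execution_profile_handle(Some(handle));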
+
 /// The client itself and the keeping of the count of active connections.
 #[derive(Clone)]
 pub struct ScyllaDbStoreInternal {
@@ -588,6 +694,22 @@ pub enum ScyllaDbStoreInternalError {
     /// A next row error in ScyllaDB
     #[error(transparent)]
     NextRowError(#[from] NextRowError),
+
+    /// A token error in ScyllaDB
+    #[error(transparent)]
+    ClusterStateTokenError(#[from] ClusterStateTokenError),
+
+    /// The token endpoint information is currently missing from the driver
+    #[error("The token endpoint information for token {0:?} is currently missing from the driver")]
+    MissingTokenEndpoints(Token),
+
+    /// The mutex is poisoned
+    #[error("The mutex is poisoned")]
+    PoisonedMutex,
+
+    /// A metadata error in ScyllaDB
+    #[error(transparent)]
+    MetadataError(#[from] MetadataError),
 }

 impl KeyValueStoreError for ScyllaDbStoreInternalError {
@@ -699,6 +821,9 @@ impl DirectWritableKeyValueStore for ScyllaDbStoreInternal {
     // https://github.com/scylladb/scylladb/blob/master/docs/dev/timestamp-conflict-resolution.md
     type Batch = UnorderedBatch;

+    // Batches should always target the same partition key. Batches across different
+    // partitions will not be atomic; if the caller wants atomicity, it is the caller's
+    // responsibility to make sure that the batch only has statements for the same
+    // partition key.
     async fn write_batch(&self, batch: Self::Batch) -> Result<(), ScyllaDbStoreInternalError> {
         let store = self.store.deref();
         let _guard = self.acquire().await;