20
20
import logging
21
21
import socket
22
22
import time
23
+ import random
23
24
from threading import Lock , RLock , Condition
24
25
import weakref
25
26
try :
@@ -123,6 +124,8 @@ class Host(object):
123
124
124
125
_currently_handling_node_up = False
125
126
127
+ sharding_info = None
128
+
126
129
def __init__ (self , endpoint , conviction_policy_factory , datacenter = None , rack = None , host_id = None ):
127
130
if endpoint is None :
128
131
raise ValueError ("endpoint may not be None" )
@@ -339,7 +342,6 @@ class HostConnection(object):
339
342
shutdown_on_error = False
340
343
341
344
_session = None
342
- _connection = None
343
345
_lock = None
344
346
_keyspace = None
345
347
@@ -351,6 +353,7 @@ def __init__(self, host, host_distance, session):
351
353
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
352
354
self ._stream_available_condition = Condition (self ._lock )
353
355
self ._is_replacing = False
356
+ self ._connections = dict ()
354
357
355
358
if host_distance == HostDistance .IGNORED :
356
359
log .debug ("Not opening connection to ignored host %s" , self .host )
@@ -360,18 +363,45 @@ def __init__(self, host, host_distance, session):
360
363
return
361
364
362
365
log .debug ("Initializing connection for host %s" , self .host )
363
- self ._connection = session .cluster .connection_factory (host .endpoint )
366
+ first_connection = session .cluster .connection_factory (host .endpoint )
367
+ log .debug ("first connection created for shard_id=%i" , first_connection .shard_id )
368
+ self ._connections [first_connection .shard_id ] = first_connection
364
369
self ._keyspace = session .keyspace
370
+
365
371
if self ._keyspace :
366
- self ._connection .set_keyspace_blocking (self ._keyspace )
372
+ first_connection .set_keyspace_blocking (self ._keyspace )
373
+
374
+ if first_connection .sharding_info :
375
+ self .host .sharding_info = weakref .proxy (first_connection .sharding_info )
376
+ for _ in range (first_connection .sharding_info .shards_count * 2 ):
377
+ conn = self ._session .cluster .connection_factory (self .host .endpoint )
378
+ if conn .shard_id not in self ._connections .keys ():
379
+ log .debug ("new connection created for shard_id=%i" , conn .shard_id )
380
+ self ._connections [conn .shard_id ] = conn
381
+ if self ._keyspace :
382
+ self ._connections [conn .shard_id ].set_keyspace_blocking (self ._keyspace )
383
+
384
+ if len (self ._connections .keys ()) == first_connection .sharding_info .shards_count :
385
+ break
386
+ if not len (self ._connections .keys ()) == first_connection .sharding_info .shards_count :
387
+ raise NoConnectionsAvailable ("not enough shard connection opened" )
388
+
367
389
log .debug ("Finished initializing connection for host %s" , self .host )
368
390
369
- def borrow_connection (self , timeout ):
391
+ def borrow_connection (self , timeout , routing_key = None ):
370
392
if self .is_shutdown :
371
393
raise ConnectionException (
372
394
"Pool for %s is shutdown" % (self .host ,), self .host )
373
395
374
- conn = self ._connection
396
+ shard_id = 0
397
+ if self .host .sharding_info :
398
+ if routing_key :
399
+ t = self ._session .cluster .metadata .token_map .token_class .from_key (routing_key )
400
+ shard_id = self .host .sharding_info .shard_id (t )
401
+ else :
402
+ shard_id = random .randint (0 , self .host .sharding_info .shards_count - 1 )
403
+
404
+ conn = self ._connections .get (shard_id )
375
405
if not conn :
376
406
raise NoConnectionsAvailable ()
377
407
@@ -416,7 +446,7 @@ def return_connection(self, connection):
416
446
if is_down :
417
447
self .shutdown ()
418
448
else :
419
- self ._connection = None
449
+ del self ._connections [ connection . shard_id ]
420
450
with self ._lock :
421
451
if self ._is_replacing :
422
452
return
@@ -433,7 +463,7 @@ def _replace(self, connection):
433
463
conn = self ._session .cluster .connection_factory (self .host .endpoint )
434
464
if self ._keyspace :
435
465
conn .set_keyspace_blocking (self ._keyspace )
436
- self ._connection = conn
466
+ self ._connections [ connection . shard_id ] = conn
437
467
except Exception :
438
468
log .warning ("Failed reconnecting %s. Retrying." % (self .host .endpoint ,))
439
469
self ._session .submit (self ._replace , connection )
@@ -450,36 +480,48 @@ def shutdown(self):
450
480
self .is_shutdown = True
451
481
self ._stream_available_condition .notify_all ()
452
482
453
- if self ._connection :
454
- self ._connection .close ()
455
- self ._connection = None
483
+ if self ._connections :
484
+ for c in self ._connections .values ():
485
+ c .close ()
486
+ self ._connections = dict ()
456
487
457
488
def _set_keyspace_for_all_conns (self , keyspace , callback ):
458
- if self .is_shutdown or not self ._connection :
489
+ """
490
+ Asynchronously sets the keyspace for all connections. When all
491
+ connections have been set, `callback` will be called with two
492
+ arguments: this pool, and a list of any errors that occurred.
493
+ """
494
+ remaining_callbacks = set (self ._connections .values ())
495
+ errors = []
496
+
497
+ if not remaining_callbacks :
498
+ callback (self , errors )
459
499
return
460
500
461
501
def connection_finished_setting_keyspace (conn , error ):
462
502
self .return_connection (conn )
463
- errors = [] if not error else [error ]
464
- callback (self , errors )
503
+ remaining_callbacks .remove (conn )
504
+ if error :
505
+ errors .append (error )
506
+
507
+ if not remaining_callbacks :
508
+ callback (self , errors )
465
509
466
510
self ._keyspace = keyspace
467
- self ._connection .set_keyspace_async (keyspace , connection_finished_setting_keyspace )
511
+ for conn in self ._connections .values ():
512
+ conn .set_keyspace_async (keyspace , connection_finished_setting_keyspace )
468
513
469
514
def get_connections (self ):
470
- c = self ._connection
471
- return [ c ] if c else []
515
+ c = self ._connections
516
+ return list ( self . _connections . values ()) if c else []
472
517
473
518
def get_state (self ):
474
- connection = self ._connection
475
- open_count = 1 if connection and not (connection .is_closed or connection .is_defunct ) else 0
476
- in_flights = [connection .in_flight ] if connection else []
477
- return {'shutdown' : self .is_shutdown , 'open_count' : open_count , 'in_flights' : in_flights }
519
+ in_flights = [c .in_flight for c in self ._connections .values ()]
520
+ return {'shutdown' : self .is_shutdown , 'open_count' : self .open_count , 'in_flights' : in_flights }
478
521
479
522
@property
480
523
def open_count (self ):
481
- connection = self ._connection
482
- return 1 if connection and not (connection .is_closed or connection .is_defunct ) else 0
524
+ return sum ([1 if c and not (c .is_closed or c .is_defunct ) else 0 for c in self ._connections .values ()])
483
525
484
526
_MAX_SIMULTANEOUS_CREATION = 1
485
527
_MIN_TRASH_INTERVAL = 10
@@ -522,7 +564,7 @@ def __init__(self, host, host_distance, session):
522
564
self .open_count = core_conns
523
565
log .debug ("Finished initializing new connection pool for host %s" , self .host )
524
566
525
- def borrow_connection (self , timeout ):
567
+ def borrow_connection (self , timeout , routing_key = None ):
526
568
if self .is_shutdown :
527
569
raise ConnectionException (
528
570
"Pool for %s is shutdown" % (self .host ,), self .host )
0 commit comments