Skip to content

Commit ac58615

Browse files
committed
Initial shard aware driver
1 parent fdfcdf5 commit ac58615

File tree

5 files changed

+111
-31
lines changed

5 files changed

+111
-31
lines changed

cassandra/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4384,7 +4384,7 @@ def _query(self, host, message=None, cb=None):
43844384
connection = None
43854385
try:
43864386
# TODO get connectTimeout from cluster settings
4387-
connection, request_id = pool.borrow_connection(timeout=2.0)
4387+
connection, request_id = pool.borrow_connection(timeout=2.0, routing_key=self.query.routing_key if self.query else None)
43884388
self._connection = connection
43894389
result_meta = self.prepared_statement.result_metadata if self.prepared_statement else []
43904390

cassandra/connection.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
RegisterMessage, ReviseRequestMessage)
4545
from cassandra.util import OrderedDict
4646

47+
MIN_LONG = -(2 ** 63)
4748

4849
log = logging.getLogger(__name__)
4950

@@ -599,6 +600,39 @@ def int_from_buf_item(i):
599600
else:
600601
int_from_buf_item = ord
601602

603+
class ShardingInfo(object):
    """
    Sharding parameters a node advertises in its OPTIONS/SUPPORTED response.

    Holds the number of shards on the node together with the partitioner,
    sharding algorithm and "ignore MSB" bit count needed to map a token to
    the shard that owns it.
    """

    # The only partitioner / algorithm pair the token-to-shard mapping in
    # `shard_id` is written for.
    _SUPPORTED_PARTITIONER = "org.apache.cassandra.dht.Murmur3Partitioner"
    _SUPPORTED_ALGORITHM = "biased-token-round-robin"

    def __init__(self, shard_id, shards_count, partitioner, sharding_algorithm, sharding_ignore_msb):
        """
        :param shard_id: shard of the connection the info was read from;
            accepted for signature compatibility, not stored on the instance.
        :param shards_count: number of shards on the node (int or numeric string).
        :param partitioner: partitioner class name reported by the node.
        :param sharding_algorithm: sharding algorithm name reported by the node.
        :param sharding_ignore_msb: number of most significant token bits to
            discard before mapping (int or numeric string).
        """
        self.shards_count = int(shards_count)
        self.partitioner = partitioner
        self.sharding_algorithm = sharding_algorithm
        self.sharding_ignore_msb = int(sharding_ignore_msb)

    @staticmethod
    def parse_sharding_info(message):
        """
        Extract sharding information from an options response.

        :param message: object whose ``options`` attribute maps option names
            to lists of string values.
        :returns: ``(shard_id, ShardingInfo)`` when every ``SCYLLA_*`` sharding
            option is present and the partitioner and algorithm are the
            supported ones; ``(0, None)`` otherwise.
        """
        def first_value(name):
            # A missing key, an empty list and an empty string all mean
            # "option not provided".
            values = message.options.get(name)
            return (values[0] or None) if values else None

        shard_id = first_value('SCYLLA_SHARD')
        shards_count = first_value('SCYLLA_NR_SHARDS')
        partitioner = first_value('SCYLLA_PARTITIONER')
        sharding_algorithm = first_value('SCYLLA_SHARDING_ALGORITHM')
        sharding_ignore_msb = first_value('SCYLLA_SHARDING_IGNORE_MSB')

        # All options are required: with any of them missing, int() below
        # would fail on None, and with a different partitioner/algorithm the
        # mapping in `shard_id` would not apply.
        if (shard_id is None or shards_count is None or sharding_ignore_msb is None
                or partitioner != ShardingInfo._SUPPORTED_PARTITIONER
                or sharding_algorithm != ShardingInfo._SUPPORTED_ALGORITHM):
            return 0, None

        return int(shard_id), ShardingInfo(shard_id, shards_count, partitioner,
                                           sharding_algorithm, sharding_ignore_msb)

    def shard_id(self, t):
        """
        Map a token to a shard index in ``range(self.shards_count)``.

        :param t: token object exposing its integer value as ``t.value``.
        :returns: the shard index as an int.
        """
        # Bias the token by MIN_LONG and drop the ignored high bits. Only the
        # low 64 bits of the shifted value take part in the mapping, which is
        # why the mask is applied (Python ints are unbounded, and `&` on a
        # negative int yields its two's-complement low bits).
        biased = ((t.value + MIN_LONG) << self.sharding_ignore_msb) & 0xffffffffffffffff
        # Scale the 64-bit value into [0, shards_count): the high 64 bits of
        # the product. Equivalent to the split 32-bit multiply-and-carry form.
        return (biased * self.shards_count) >> 64
602636

603637
class Connection(object):
604638

@@ -666,6 +700,9 @@ class Connection(object):
666700
_check_hostname = False
667701
_product_type = None
668702

703+
shard_id = 0
704+
sharding_info = None
705+
669706
def __init__(self, host='127.0.0.1', port=9042, authenticator=None,
670707
ssl_options=None, sockopts=None, compression=True,
671708
cql_version=None, protocol_version=ProtocolVersion.MAX_SUPPORTED, is_control_connection=False,
@@ -1126,6 +1163,7 @@ def _send_options_message(self):
11261163

11271164
@defunct_on_error
11281165
def _handle_options_response(self, options_response):
1166+
self.shard_id, self.sharding_info = ShardingInfo.parse_sharding_info(options_response)
11291167
if self.is_defunct:
11301168
return
11311169

cassandra/pool.py

Lines changed: 65 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import socket
2222
import time
23+
import random
2324
from threading import Lock, RLock, Condition
2425
import weakref
2526
try:
@@ -123,6 +124,8 @@ class Host(object):
123124

124125
_currently_handling_node_up = False
125126

127+
sharding_info = None
128+
126129
def __init__(self, endpoint, conviction_policy_factory, datacenter=None, rack=None, host_id=None):
127130
if endpoint is None:
128131
raise ValueError("endpoint may not be None")
@@ -339,7 +342,6 @@ class HostConnection(object):
339342
shutdown_on_error = False
340343

341344
_session = None
342-
_connection = None
343345
_lock = None
344346
_keyspace = None
345347

@@ -351,6 +353,7 @@ def __init__(self, host, host_distance, session):
351353
# this is used in conjunction with the connection streams. Not using the connection lock because the connection can be replaced in the lifetime of the pool.
352354
self._stream_available_condition = Condition(self._lock)
353355
self._is_replacing = False
356+
self._connections = dict()
354357

355358
if host_distance == HostDistance.IGNORED:
356359
log.debug("Not opening connection to ignored host %s", self.host)
@@ -360,18 +363,45 @@ def __init__(self, host, host_distance, session):
360363
return
361364

362365
log.debug("Initializing connection for host %s", self.host)
363-
self._connection = session.cluster.connection_factory(host.endpoint)
366+
first_connection = session.cluster.connection_factory(host.endpoint)
367+
log.debug("first connection created for shard_id=%i", first_connection.shard_id)
368+
self._connections[first_connection.shard_id] = first_connection
364369
self._keyspace = session.keyspace
370+
365371
if self._keyspace:
366-
self._connection.set_keyspace_blocking(self._keyspace)
372+
first_connection.set_keyspace_blocking(self._keyspace)
373+
374+
if first_connection.sharding_info:
375+
self.host.sharding_info = weakref.proxy(first_connection.sharding_info)
376+
for _ in range(first_connection.sharding_info.shards_count * 2):
377+
conn = self._session.cluster.connection_factory(self.host.endpoint)
378+
if conn.shard_id not in self._connections.keys():
379+
log.debug("new connection created for shard_id=%i", conn.shard_id)
380+
self._connections[conn.shard_id] = conn
381+
if self._keyspace:
382+
self._connections[conn.shard_id].set_keyspace_blocking(self._keyspace)
383+
384+
if len(self._connections.keys()) == first_connection.sharding_info.shards_count:
385+
break
386+
if not len(self._connections.keys()) == first_connection.sharding_info.shards_count:
387+
raise NoConnectionsAvailable("not enough shard connection opened")
388+
367389
log.debug("Finished initializing connection for host %s", self.host)
368390

369-
def borrow_connection(self, timeout):
391+
def borrow_connection(self, timeout, routing_key=None):
370392
if self.is_shutdown:
371393
raise ConnectionException(
372394
"Pool for %s is shutdown" % (self.host,), self.host)
373395

374-
conn = self._connection
396+
shard_id = 0
397+
if self.host.sharding_info:
398+
if routing_key:
399+
t = self._session.cluster.metadata.token_map.token_class.from_key(routing_key)
400+
shard_id =self.host.sharding_info.shard_id(t)
401+
else:
402+
shard_id = random.randint(0, self.host.sharding_info.shards_count - 1)
403+
404+
conn = self._connections.get(shard_id)
375405
if not conn:
376406
raise NoConnectionsAvailable()
377407

@@ -416,7 +446,7 @@ def return_connection(self, connection):
416446
if is_down:
417447
self.shutdown()
418448
else:
419-
self._connection = None
449+
del self._connections[connection.shard_id]
420450
with self._lock:
421451
if self._is_replacing:
422452
return
@@ -433,7 +463,7 @@ def _replace(self, connection):
433463
conn = self._session.cluster.connection_factory(self.host.endpoint)
434464
if self._keyspace:
435465
conn.set_keyspace_blocking(self._keyspace)
436-
self._connection = conn
466+
self._connections[connection.shard_id] = conn
437467
except Exception:
438468
log.warning("Failed reconnecting %s. Retrying." % (self.host.endpoint,))
439469
self._session.submit(self._replace, connection)
@@ -450,36 +480,48 @@ def shutdown(self):
450480
self.is_shutdown = True
451481
self._stream_available_condition.notify_all()
452482

453-
if self._connection:
454-
self._connection.close()
455-
self._connection = None
483+
if self._connections:
484+
for c in self._connections.values():
485+
c.close()
486+
self._connections = dict()
456487

457488
def _set_keyspace_for_all_conns(self, keyspace, callback):
458-
if self.is_shutdown or not self._connection:
489+
"""
490+
Asynchronously sets the keyspace for all connections. When all
491+
connections have been set, `callback` will be called with two
492+
arguments: this pool, and a list of any errors that occurred.
493+
"""
494+
remaining_callbacks = set(self._connections.values())
495+
errors = []
496+
497+
if not remaining_callbacks:
498+
callback(self, errors)
459499
return
460500

461501
def connection_finished_setting_keyspace(conn, error):
462502
self.return_connection(conn)
463-
errors = [] if not error else [error]
464-
callback(self, errors)
503+
remaining_callbacks.remove(conn)
504+
if error:
505+
errors.append(error)
506+
507+
if not remaining_callbacks:
508+
callback(self, errors)
465509

466510
self._keyspace = keyspace
467-
self._connection.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
511+
for conn in self._connections.values():
512+
conn.set_keyspace_async(keyspace, connection_finished_setting_keyspace)
468513

469514
def get_connections(self):
470-
c = self._connection
471-
return [c] if c else []
515+
c = self._connections
516+
return list(self._connections.values()) if c else []
472517

473518
def get_state(self):
474-
connection = self._connection
475-
open_count = 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
476-
in_flights = [connection.in_flight] if connection else []
477-
return {'shutdown': self.is_shutdown, 'open_count': open_count, 'in_flights': in_flights}
519+
in_flights = [c.in_flight for c in self._connections.values()]
520+
return {'shutdown': self.is_shutdown, 'open_count': self.open_count, 'in_flights': in_flights}
478521

479522
@property
480523
def open_count(self):
481-
connection = self._connection
482-
return 1 if connection and not (connection.is_closed or connection.is_defunct) else 0
524+
return sum([1 if c and not (c.is_closed or c.is_defunct) else 0 for c in self._connections.values()])
483525

484526
_MAX_SIMULTANEOUS_CREATION = 1
485527
_MIN_TRASH_INTERVAL = 10
@@ -522,7 +564,7 @@ def __init__(self, host, host_distance, session):
522564
self.open_count = core_conns
523565
log.debug("Finished initializing new connection pool for host %s", self.host)
524566

525-
def borrow_connection(self, timeout):
567+
def borrow_connection(self, timeout, routing_key=None):
526568
if self.is_shutdown:
527569
raise ConnectionException(
528570
"Pool for %s is shutdown" % (self.host,), self.host)

tests/integration/standard/test_connection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,8 +172,8 @@ def fetch_connections(self, host, cluster):
172172
if conn._connections is not None and len(conn._connections) > 0:
173173
connections.append(conn._connections)
174174
else:
175-
if conn._connection is not None:
176-
connections.append(conn._connection)
175+
if conn._connections and len(conn._connections.values()) > 0:
176+
connections.append(conn._connections.values())
177177
return connections
178178

179179
def wait_for_connections(self, host, cluster):

tests/unit/test_response_future.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ def test_result_message(self):
7474
rf.send_request()
7575

7676
rf.session._pools.get.assert_called_once_with('ip1')
77-
pool.borrow_connection.assert_called_once_with(timeout=ANY)
77+
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY)
7878

7979
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
8080

@@ -256,7 +256,7 @@ def test_retry_policy_says_retry(self):
256256
rf.send_request()
257257

258258
rf.session._pools.get.assert_called_once_with('ip1')
259-
pool.borrow_connection.assert_called_once_with(timeout=ANY)
259+
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY)
260260
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
261261

262262
result = Mock(spec=UnavailableErrorMessage, info={})
@@ -275,7 +275,7 @@ def test_retry_policy_says_retry(self):
275275
# it should try again with the same host since this was
276276
# an UnavailableException
277277
rf.session._pools.get.assert_called_with(host)
278-
pool.borrow_connection.assert_called_with(timeout=ANY)
278+
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY)
279279
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
280280

281281
def test_retry_with_different_host(self):
@@ -290,7 +290,7 @@ def test_retry_with_different_host(self):
290290
rf.send_request()
291291

292292
rf.session._pools.get.assert_called_once_with('ip1')
293-
pool.borrow_connection.assert_called_once_with(timeout=ANY)
293+
pool.borrow_connection.assert_called_once_with(timeout=ANY, routing_key=ANY)
294294
connection.send_msg.assert_called_once_with(rf.message, 1, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
295295
self.assertEqual(ConsistencyLevel.QUORUM, rf.message.consistency_level)
296296

@@ -309,7 +309,7 @@ def test_retry_with_different_host(self):
309309

310310
# it should try with a different host
311311
rf.session._pools.get.assert_called_with('ip2')
312-
pool.borrow_connection.assert_called_with(timeout=ANY)
312+
pool.borrow_connection.assert_called_with(timeout=ANY, routing_key=ANY)
313313
connection.send_msg.assert_called_with(rf.message, 2, cb=ANY, encoder=ProtocolHandler.encode_message, decoder=ProtocolHandler.decode_message, result_metadata=[])
314314

315315
# the consistency level should be the same

0 commit comments

Comments
 (0)