Skip to content

Commit 16a903d

Browse files
committed
Add unit/integration tests for shard aware
1 parent ac58615 commit 16a903d

File tree

2 files changed

+163
-0
lines changed

2 files changed

+163
-0
lines changed
Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
import os
15+
16+
try:
17+
import unittest2 as unittest
18+
except ImportError:
19+
import unittest # noqa
20+
21+
from cassandra.cluster import Cluster
22+
from cassandra.policies import TokenAwarePolicy, RoundRobinPolicy
23+
24+
from tests.integration import use_cluster
25+
26+
27+
def setup_module():
28+
os.environ['SCYLLA_EXT_OPTS'] = "--smp 4 --memory 2048M"
29+
use_cluster('shared_aware', [1], start=True)
30+
31+
32+
class TestShardAwareIntegration(unittest.TestCase):
33+
@classmethod
34+
def setup_class(cls):
35+
cls.cluster = Cluster(protocol_version=4, load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()))
36+
cls.session = cls.cluster.connect()
37+
38+
@classmethod
39+
def teardown_class(cls):
40+
cls.cluster.shutdown()
41+
42+
def verify_same_shard_in_tracing(self, results, shard_name):
43+
trace_id = results.response_future.get_query_trace_ids()[0]
44+
traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id, ))
45+
events = [event for event in traces]
46+
for event in events:
47+
print(event.thread, event.activity)
48+
for event in events:
49+
self.assertEqual(event.thread, shard_name)
50+
self.assertIn('querying locally', "\n".join([event.activity for event in events]))
51+
52+
def test_all_tracing_coming_one_shard(self):
53+
"""
54+
Testing that shard aware driver is sending the requests to the correct shards
55+
56+
using the traces to validate that all the action been executed on the the same shard.
57+
this test is using prepared SELECT statements for this validation
58+
"""
59+
60+
self.session.execute(
61+
"""
62+
DROP KEYSPACE IF EXISTS preparedtests
63+
"""
64+
)
65+
self.session.execute(
66+
"""
67+
CREATE KEYSPACE preparedtests
68+
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
69+
""")
70+
71+
self.session.execute("USE preparedtests")
72+
self.session.execute(
73+
"""
74+
CREATE TABLE cf0 (
75+
a text,
76+
b text,
77+
c text,
78+
PRIMARY KEY (a, b)
79+
)
80+
""")
81+
82+
prepared = self.session.prepare(
83+
"""
84+
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
85+
""")
86+
87+
bound = prepared.bind(('a', 'b', 'c'))
88+
89+
self.session.execute(bound)
90+
91+
bound = prepared.bind(('e', 'f', 'g'))
92+
93+
self.session.execute(bound)
94+
95+
bound = prepared.bind(('100000', 'f', 'g'))
96+
97+
self.session.execute(bound)
98+
99+
prepared = self.session.prepare(
100+
"""
101+
SELECT * FROM cf0 WHERE a=? AND b=?
102+
""")
103+
104+
bound = prepared.bind(('a', 'b'))
105+
results = self.session.execute(bound, trace=True)
106+
self.assertEqual(results, [('a', 'b', 'c')])
107+
108+
self.verify_same_shard_in_tracing(results, "shard 1")
109+
110+
bound = prepared.bind(('100000', 'f'))
111+
results = self.session.execute(bound, trace=True)
112+
self.assertEqual(results, [('100000', 'f', 'g')])
113+
114+
self.verify_same_shard_in_tracing(results, "shard 0")
115+
116+
bound = prepared.bind(('e', 'f'))
117+
results = self.session.execute(bound, trace=True)
118+
119+
self.verify_same_shard_in_tracing(results, "shard 1")

tests/unit/test_shard_aware.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
try:
16+
import unittest2 as unittest
17+
except ImportError:
18+
import unittest # noqa
19+
20+
from cassandra.connection import ShardingInfo
21+
from cassandra.metadata import Murmur3Token
22+
23+
class TestShardAware(unittest.TestCase):
24+
def test_parsing_and_calculating_shard_id(self):
25+
'''
26+
Testing the parsing of the options command
27+
and the calculation getting a shard id from a Murmur3 token
28+
'''
29+
class OptionsHolder():
30+
options = {
31+
'SCYLLA_SHARD': ['1'],
32+
'SCYLLA_NR_SHARDS': ['12'],
33+
'SCYLLA_PARTITIONER': ['org.apache.cassandra.dht.Murmur3Partitioner'],
34+
'SCYLLA_SHARDING_ALGORITHM': ['biased-token-round-robin'],
35+
'SCYLLA_SHARDING_IGNORE_MSB': ['12']
36+
}
37+
shard_id, shard_info = ShardingInfo.parse_sharding_info(OptionsHolder())
38+
39+
self.assertEqual(shard_id, 1)
40+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"a")), 4)
41+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"b")), 6)
42+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"c")), 6)
43+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"e")), 4)
44+
self.assertEqual(shard_info.shard_id(Murmur3Token.from_key(b"100000")), 2)

0 commit comments

Comments
 (0)