Skip to content

Commit dedf96b

Browse files
Copilotmykaul
andcommitted
Add __slots__ to 12 lightweight connection classes for memory optimization
Co-authored-by: mykaul <[email protected]>
1 parent 64b8490 commit dedf96b

File tree

3 files changed

+133
-6
lines changed

3 files changed

+133
-6
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,4 @@ tests/unit/cython/bytesio_testhelper.c
4444
#iPython
4545
*.ipynb
4646

47+
test_slots_implementation.py

cassandra/connection.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,7 @@ class EndPoint(object):
129129
"""
130130
Represents the information to connect to a cassandra node.
131131
"""
132+
__slots__ = ()
132133

133134
@property
134135
def address(self):
@@ -189,6 +190,7 @@ class DefaultEndPoint(EndPoint):
189190
"""
190191
Default EndPoint implementation, basically just an address and port.
191192
"""
193+
__slots__ = ('_address', '_port')
192194

193195
def __init__(self, address, port=9042):
194196
self._address = address
@@ -251,6 +253,7 @@ def create(self, row):
251253
@total_ordering
252254
class SniEndPoint(EndPoint):
253255
"""SNI Proxy EndPoint implementation."""
256+
__slots__ = ('_proxy_address', '_index', '_resolved_address', '_port', '_server_name', '_ssl_options')
254257

255258
def __init__(self, proxy_address, server_name, port=9042):
256259
self._proxy_address = proxy_address
@@ -330,6 +333,7 @@ class UnixSocketEndPoint(EndPoint):
330333
"""
331334
Unix Socket EndPoint implementation.
332335
"""
336+
__slots__ = ('_unix_socket_path',)
333337

334338
def __init__(self, unix_socket_path):
335339
self._unix_socket_path = unix_socket_path
@@ -367,6 +371,8 @@ def __repr__(self):
367371

368372

369373
class _Frame(object):
374+
__slots__ = ('version', 'flags', 'stream', 'opcode', 'body_offset', 'end_pos')
375+
370376
def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
371377
self.version = version
372378
self.flags = flags
@@ -470,6 +476,8 @@ def __init__(self, max_queue_size):
470476

471477

472478
class ContinuousPagingSession(object):
479+
__slots__ = ('stream_id', 'decoder', 'row_factory', 'connection', '_condition', '_stop', '_page_queue', '_state', 'released')
480+
473481
def __init__(self, stream_id, decoder, row_factory, connection, state):
474482
self.stream_id = stream_id
475483
self.decoder = decoder
@@ -624,14 +632,13 @@ class _ConnectionIOBuffer(object):
624632
protocol V5 and checksumming, the data is read, validated and copied to another
625633
cql frame buffer.
626634
"""
627-
_io_buffer = None
628-
_cql_frame_buffer = None
629-
_connection = None
630-
_segment_consumed = False
635+
__slots__ = ('_io_buffer', '_cql_frame_buffer', '_connection', '_segment_consumed')
631636

632637
def __init__(self, connection):
633638
self._io_buffer = io.BytesIO()
634639
self._connection = weakref.proxy(connection)
640+
self._cql_frame_buffer = None
641+
self._segment_consumed = False
635642

636643
@property
637644
def io_buffer(self):
@@ -673,6 +680,8 @@ def reset_cql_frame_buffer(self):
673680

674681

675682
class ShardawarePortGenerator:
683+
__slots__ = ()
684+
676685
@classmethod
677686
def generate(cls, shard_id, total_shards):
678687
start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
@@ -1593,6 +1602,7 @@ def __str__(self):
15931602

15941603

15951604
class ResponseWaiter(object):
1605+
__slots__ = ('connection', 'pending', 'fail_on_error', 'error', 'responses', 'event')
15961606

15971607
def __init__(self, connection, num_responses, fail_on_error):
15981608
self.connection = connection
@@ -1643,6 +1653,8 @@ def deliver(self, timeout=None):
16431653

16441654

16451655
class HeartbeatFuture(object):
1656+
__slots__ = ('_exception', '_event', 'connection', 'owner')
1657+
16461658
def __init__(self, connection, owner):
16471659
self._exception = None
16481660
self._event = Event()
@@ -1765,12 +1777,12 @@ def _raise_if_stopped(self):
17651777

17661778

17671779
class Timer(object):
1768-
1769-
canceled = False
1780+
__slots__ = ('end', 'callback', 'canceled')
17701781

17711782
def __init__(self, timeout, callback):
17721783
self.end = time.time() + timeout
17731784
self.callback = callback
1785+
self.canceled = False
17741786

17751787
def __lt__(self, other):
17761788
return self.end < other.end
@@ -1790,6 +1802,7 @@ def finish(self, time_now):
17901802

17911803

17921804
class TimerManager(object):
1805+
__slots__ = ('_queue', '_new_timers')
17931806

17941807
def __init__(self):
17951808
self._queue = []
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
"""
2+
Test for __slots__ implementation in connection module classes.
3+
This test ensures that memory optimization via __slots__ is working correctly.
4+
"""
5+
import unittest
6+
from cassandra.connection import (
7+
EndPoint, DefaultEndPoint, SniEndPoint, UnixSocketEndPoint,
8+
_Frame, ContinuousPagingSession, ShardawarePortGenerator,
9+
_ConnectionIOBuffer, ResponseWaiter, HeartbeatFuture, Timer, TimerManager
10+
)
11+
12+
13+
class SlotsImplementationTest(unittest.TestCase):
14+
"""Test that targeted classes have __slots__ and prevent dynamic attributes."""
15+
16+
def test_endpoint_classes_have_slots(self):
17+
"""Test EndPoint and its subclasses have __slots__ implemented."""
18+
# EndPoint base class should have empty slots
19+
self.assertEqual(EndPoint.__slots__, ())
20+
21+
# Test DefaultEndPoint
22+
ep = DefaultEndPoint('127.0.0.1', 9042)
23+
self.assertFalse(hasattr(ep, '__dict__'))
24+
with self.assertRaises(AttributeError):
25+
ep.dynamic_attr = 'test'
26+
27+
# Test SniEndPoint
28+
sni_ep = SniEndPoint('proxy.example.com', 'server.example.com', 9042)
29+
self.assertFalse(hasattr(sni_ep, '__dict__'))
30+
with self.assertRaises(AttributeError):
31+
sni_ep.dynamic_attr = 'test'
32+
33+
# Test UnixSocketEndPoint
34+
unix_ep = UnixSocketEndPoint('/tmp/cassandra.sock')
35+
self.assertFalse(hasattr(unix_ep, '__dict__'))
36+
with self.assertRaises(AttributeError):
37+
unix_ep.dynamic_attr = 'test'
38+
39+
def test_frame_class_has_slots(self):
40+
"""Test _Frame class has __slots__ implemented."""
41+
frame = _Frame(4, 0, 1, 7, 9, 100)
42+
self.assertFalse(hasattr(frame, '__dict__'))
43+
with self.assertRaises(AttributeError):
44+
frame.dynamic_attr = 'test'
45+
46+
# Test that all expected attributes are accessible
47+
self.assertEqual(frame.version, 4)
48+
self.assertEqual(frame.flags, 0)
49+
self.assertEqual(frame.stream, 1)
50+
self.assertEqual(frame.opcode, 7)
51+
self.assertEqual(frame.body_offset, 9)
52+
self.assertEqual(frame.end_pos, 100)
53+
54+
def test_timer_classes_have_slots(self):
55+
"""Test Timer and TimerManager classes have __slots__ implemented."""
56+
# Test Timer
57+
timer = Timer(5.0, lambda: None)
58+
self.assertFalse(hasattr(timer, '__dict__'))
59+
with self.assertRaises(AttributeError):
60+
timer.dynamic_attr = 'test'
61+
62+
# Test Timer attributes
63+
self.assertEqual(timer.canceled, False)
64+
self.assertIsNotNone(timer.end)
65+
self.assertIsNotNone(timer.callback)
66+
67+
# Test TimerManager
68+
timer_mgr = TimerManager()
69+
self.assertFalse(hasattr(timer_mgr, '__dict__'))
70+
with self.assertRaises(AttributeError):
71+
timer_mgr.dynamic_attr = 'test'
72+
73+
def test_utility_classes_have_slots(self):
74+
"""Test utility classes have __slots__ implemented."""
75+
# Test ShardawarePortGenerator
76+
self.assertEqual(ShardawarePortGenerator.__slots__, ())
77+
78+
# Test _ConnectionIOBuffer
79+
class MockConnection:
80+
pass
81+
82+
io_buffer = _ConnectionIOBuffer(MockConnection())
83+
self.assertFalse(hasattr(io_buffer, '__dict__'))
84+
with self.assertRaises(AttributeError):
85+
io_buffer.dynamic_attr = 'test'
86+
87+
# Test ResponseWaiter
88+
response_waiter = ResponseWaiter(MockConnection(), 2, True)
89+
self.assertFalse(hasattr(response_waiter, '__dict__'))
90+
with self.assertRaises(AttributeError):
91+
response_waiter.dynamic_attr = 'test'
92+
93+
def test_slots_prevent_memory_overhead(self):
94+
"""Test that objects with __slots__ don't have __dict__ overhead."""
95+
instances = [
96+
DefaultEndPoint('127.0.0.1', 9042),
97+
SniEndPoint('proxy.example.com', 'server.example.com', 9042),
98+
UnixSocketEndPoint('/tmp/cassandra.sock'),
99+
_Frame(4, 0, 1, 7, 9, 100),
100+
Timer(5.0, lambda: None),
101+
TimerManager(),
102+
]
103+
104+
for instance in instances:
105+
with self.subTest(instance=instance.__class__.__name__):
106+
# Ensure no __dict__ is present (memory optimization)
107+
self.assertFalse(hasattr(instance, '__dict__'))
108+
# Ensure __slots__ is defined
109+
self.assertTrue(hasattr(instance.__class__, '__slots__'))
110+
111+
112+
if __name__ == '__main__':
113+
unittest.main()

0 commit comments

Comments
 (0)