Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,4 @@ tests/unit/cython/bytesio_testhelper.c
#iPython
*.ipynb

test_slots_implementation.py
25 changes: 19 additions & 6 deletions cassandra/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@ class EndPoint(object):
"""
Represents the information to connect to a cassandra node.
"""
__slots__ = ()

@property
def address(self):
Expand Down Expand Up @@ -189,6 +190,7 @@ class DefaultEndPoint(EndPoint):
"""
Default EndPoint implementation, basically just an address and port.
"""
__slots__ = ('_address', '_port')

def __init__(self, address, port=9042):
self._address = address
Expand Down Expand Up @@ -251,6 +253,7 @@ def create(self, row):
@total_ordering
class SniEndPoint(EndPoint):
"""SNI Proxy EndPoint implementation."""
__slots__ = ('_proxy_address', '_index', '_resolved_address', '_port', '_server_name', '_ssl_options')

def __init__(self, proxy_address, server_name, port=9042):
self._proxy_address = proxy_address
Expand Down Expand Up @@ -330,6 +333,7 @@ class UnixSocketEndPoint(EndPoint):
"""
Unix Socket EndPoint implementation.
"""
__slots__ = ('_unix_socket_path',)

def __init__(self, unix_socket_path):
self._unix_socket_path = unix_socket_path
Expand Down Expand Up @@ -367,6 +371,8 @@ def __repr__(self):


class _Frame(object):
__slots__ = ('version', 'flags', 'stream', 'opcode', 'body_offset', 'end_pos')

def __init__(self, version, flags, stream, opcode, body_offset, end_pos):
self.version = version
self.flags = flags
Expand Down Expand Up @@ -470,6 +476,8 @@ def __init__(self, max_queue_size):


class ContinuousPagingSession(object):
__slots__ = ('stream_id', 'decoder', 'row_factory', 'connection', '_condition', '_stop', '_page_queue', '_state', 'released')

def __init__(self, stream_id, decoder, row_factory, connection, state):
self.stream_id = stream_id
self.decoder = decoder
Expand Down Expand Up @@ -624,14 +632,13 @@ class _ConnectionIOBuffer(object):
protocol V5 and checksumming, the data is read, validated and copied to another
cql frame buffer.
"""
_io_buffer = None
_cql_frame_buffer = None
_connection = None
_segment_consumed = False
__slots__ = ('_io_buffer', '_cql_frame_buffer', '_connection', '_segment_consumed')

def __init__(self, connection):
self._io_buffer = io.BytesIO()
self._connection = weakref.proxy(connection)
self._cql_frame_buffer = None
self._segment_consumed = False

@property
def io_buffer(self):
Expand Down Expand Up @@ -673,6 +680,8 @@ def reset_cql_frame_buffer(self):


class ShardawarePortGenerator:
__slots__ = ()

@classmethod
def generate(cls, shard_id, total_shards):
start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH)
Expand Down Expand Up @@ -1593,6 +1602,7 @@ def __str__(self):


class ResponseWaiter(object):
__slots__ = ('connection', 'pending', 'fail_on_error', 'error', 'responses', 'event')

def __init__(self, connection, num_responses, fail_on_error):
self.connection = connection
Expand Down Expand Up @@ -1643,6 +1653,8 @@ def deliver(self, timeout=None):


class HeartbeatFuture(object):
__slots__ = ('_exception', '_event', 'connection', 'owner')

def __init__(self, connection, owner):
self._exception = None
self._event = Event()
Expand Down Expand Up @@ -1765,12 +1777,12 @@ def _raise_if_stopped(self):


class Timer(object):

canceled = False
__slots__ = ('end', 'callback', 'canceled')

def __init__(self, timeout, callback):
self.end = time.time() + timeout
self.callback = callback
self.canceled = False

def __lt__(self, other):
return self.end < other.end
Expand All @@ -1790,6 +1802,7 @@ def finish(self, time_now):


class TimerManager(object):
__slots__ = ('_queue', '_new_timers')

def __init__(self):
self._queue = []
Expand Down
113 changes: 113 additions & 0 deletions tests/unit/test_connection_slots.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Test for __slots__ implementation in connection module classes.
This test ensures that memory optimization via __slots__ is working correctly.
"""
import unittest
from cassandra.connection import (
EndPoint, DefaultEndPoint, SniEndPoint, UnixSocketEndPoint,
_Frame, ContinuousPagingSession, ShardawarePortGenerator,
_ConnectionIOBuffer, ResponseWaiter, HeartbeatFuture, Timer, TimerManager
)


class SlotsImplementationTest(unittest.TestCase):
"""Test that targeted classes have __slots__ and prevent dynamic attributes."""

def test_endpoint_classes_have_slots(self):
"""Test EndPoint and its subclasses have __slots__ implemented."""
# EndPoint base class should have empty slots
self.assertEqual(EndPoint.__slots__, ())

# Test DefaultEndPoint
ep = DefaultEndPoint('127.0.0.1', 9042)
self.assertFalse(hasattr(ep, '__dict__'))
with self.assertRaises(AttributeError):
ep.dynamic_attr = 'test'

# Test SniEndPoint
sni_ep = SniEndPoint('proxy.example.com', 'server.example.com', 9042)
self.assertFalse(hasattr(sni_ep, '__dict__'))
with self.assertRaises(AttributeError):
sni_ep.dynamic_attr = 'test'

# Test UnixSocketEndPoint
unix_ep = UnixSocketEndPoint('/tmp/cassandra.sock')
self.assertFalse(hasattr(unix_ep, '__dict__'))
with self.assertRaises(AttributeError):
unix_ep.dynamic_attr = 'test'

def test_frame_class_has_slots(self):
"""Test _Frame class has __slots__ implemented."""
frame = _Frame(4, 0, 1, 7, 9, 100)
self.assertFalse(hasattr(frame, '__dict__'))
with self.assertRaises(AttributeError):
frame.dynamic_attr = 'test'

# Test that all expected attributes are accessible
self.assertEqual(frame.version, 4)
self.assertEqual(frame.flags, 0)
self.assertEqual(frame.stream, 1)
self.assertEqual(frame.opcode, 7)
self.assertEqual(frame.body_offset, 9)
self.assertEqual(frame.end_pos, 100)

def test_timer_classes_have_slots(self):
"""Test Timer and TimerManager classes have __slots__ implemented."""
# Test Timer
timer = Timer(5.0, lambda: None)
self.assertFalse(hasattr(timer, '__dict__'))
with self.assertRaises(AttributeError):
timer.dynamic_attr = 'test'

# Test Timer attributes
self.assertEqual(timer.canceled, False)
self.assertIsNotNone(timer.end)
self.assertIsNotNone(timer.callback)

# Test TimerManager
timer_mgr = TimerManager()
self.assertFalse(hasattr(timer_mgr, '__dict__'))
with self.assertRaises(AttributeError):
timer_mgr.dynamic_attr = 'test'

def test_utility_classes_have_slots(self):
"""Test utility classes have __slots__ implemented."""
# Test ShardawarePortGenerator
self.assertEqual(ShardawarePortGenerator.__slots__, ())

# Test _ConnectionIOBuffer
class MockConnection:
pass

io_buffer = _ConnectionIOBuffer(MockConnection())
self.assertFalse(hasattr(io_buffer, '__dict__'))
with self.assertRaises(AttributeError):
io_buffer.dynamic_attr = 'test'

# Test ResponseWaiter
response_waiter = ResponseWaiter(MockConnection(), 2, True)
self.assertFalse(hasattr(response_waiter, '__dict__'))
with self.assertRaises(AttributeError):
response_waiter.dynamic_attr = 'test'

def test_slots_prevent_memory_overhead(self):
"""Test that objects with __slots__ don't have __dict__ overhead."""
instances = [
DefaultEndPoint('127.0.0.1', 9042),
SniEndPoint('proxy.example.com', 'server.example.com', 9042),
UnixSocketEndPoint('/tmp/cassandra.sock'),
_Frame(4, 0, 1, 7, 9, 100),
Timer(5.0, lambda: None),
TimerManager(),
]

for instance in instances:
with self.subTest(instance=instance.__class__.__name__):
# Ensure no __dict__ is present (memory optimization)
self.assertFalse(hasattr(instance, '__dict__'))
# Ensure __slots__ is defined
self.assertTrue(hasattr(instance.__class__, '__slots__'))


if __name__ == '__main__':
unittest.main()