diff --git a/.gitignore b/.gitignore index 4541d034f0..6ebda6c0cf 100644 --- a/.gitignore +++ b/.gitignore @@ -44,3 +44,4 @@ tests/unit/cython/bytesio_testhelper.c #iPython *.ipynb +test_slots_implementation.py diff --git a/cassandra/connection.py b/cassandra/connection.py index c3ba42d725..c7a4f5db9c 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -129,6 +129,7 @@ class EndPoint(object): """ Represents the information to connect to a cassandra node. """ + __slots__ = () @property def address(self): @@ -189,6 +190,7 @@ class DefaultEndPoint(EndPoint): """ Default EndPoint implementation, basically just an address and port. """ + __slots__ = ('_address', '_port') def __init__(self, address, port=9042): self._address = address @@ -251,6 +253,7 @@ def create(self, row): @total_ordering class SniEndPoint(EndPoint): """SNI Proxy EndPoint implementation.""" + __slots__ = ('_proxy_address', '_index', '_resolved_address', '_port', '_server_name', '_ssl_options') def __init__(self, proxy_address, server_name, port=9042): self._proxy_address = proxy_address @@ -330,6 +333,7 @@ class UnixSocketEndPoint(EndPoint): """ Unix Socket EndPoint implementation. """ + __slots__ = ('_unix_socket_path',) def __init__(self, unix_socket_path): self._unix_socket_path = unix_socket_path @@ -367,6 +371,8 @@ def __repr__(self): class _Frame(object): + __slots__ = ('version', 'flags', 'stream', 'opcode', 'body_offset', 'end_pos') + def __init__(self, version, flags, stream, opcode, body_offset, end_pos): self.version = version self.flags = flags @@ -470,6 +476,8 @@ def __init__(self, max_queue_size): class ContinuousPagingSession(object): + __slots__ = ('stream_id', 'decoder', 'row_factory', 'connection', '_condition', '_stop', '_page_queue', '_state', 'released') + def __init__(self, stream_id, decoder, row_factory, connection, state): self.stream_id = stream_id self.decoder = decoder @@ -624,14 +632,13 @@ class _ConnectionIOBuffer(object): protocol V5 and checksumming, the data is read, validated and copied to another cql frame buffer. """ - _io_buffer = None - _cql_frame_buffer = None - _connection = None - _segment_consumed = False + __slots__ = ('_io_buffer', '_cql_frame_buffer', '_connection', '_segment_consumed') def __init__(self, connection): self._io_buffer = io.BytesIO() self._connection = weakref.proxy(connection) + self._cql_frame_buffer = None + self._segment_consumed = False @property def io_buffer(self): @@ -673,6 +680,8 @@ def reset_cql_frame_buffer(self): class ShardawarePortGenerator: + __slots__ = () + @classmethod def generate(cls, shard_id, total_shards): start = random.randrange(DEFAULT_LOCAL_PORT_LOW, DEFAULT_LOCAL_PORT_HIGH) @@ -1593,6 +1602,7 @@ def __str__(self): class ResponseWaiter(object): + __slots__ = ('connection', 'pending', 'fail_on_error', 'error', 'responses', 'event') def __init__(self, connection, num_responses, fail_on_error): self.connection = connection @@ -1643,6 +1653,8 @@ def deliver(self, timeout=None): class HeartbeatFuture(object): + __slots__ = ('_exception', '_event', 'connection', 'owner') + def __init__(self, connection, owner): self._exception = None self._event = Event() @@ -1765,12 +1777,12 @@ def _raise_if_stopped(self): class Timer(object): - - canceled = False + __slots__ = ('end', 'callback', 'canceled') def __init__(self, timeout, callback): self.end = time.time() + timeout self.callback = callback + self.canceled = False def __lt__(self, other): return self.end < other.end @@ -1790,6 +1802,7 @@ def finish(self, time_now): class TimerManager(object): + __slots__ = ('_queue', '_new_timers') def __init__(self): self._queue = [] diff --git a/tests/unit/test_connection_slots.py b/tests/unit/test_connection_slots.py new file mode 100644 index 0000000000..21de9995b5 --- /dev/null +++ b/tests/unit/test_connection_slots.py @@ -0,0 +1,113 @@ +""" +Test for __slots__ implementation in connection module classes. +This test ensures that memory optimization via __slots__ is working correctly. +""" +import unittest +from cassandra.connection import ( + EndPoint, DefaultEndPoint, SniEndPoint, UnixSocketEndPoint, + _Frame, ContinuousPagingSession, ShardawarePortGenerator, + _ConnectionIOBuffer, ResponseWaiter, HeartbeatFuture, Timer, TimerManager +) + + +class SlotsImplementationTest(unittest.TestCase): + """Test that targeted classes have __slots__ and prevent dynamic attributes.""" + + def test_endpoint_classes_have_slots(self): + """Test EndPoint and its subclasses have __slots__ implemented.""" + # EndPoint base class should have empty slots + self.assertEqual(EndPoint.__slots__, ()) + + # Test DefaultEndPoint + ep = DefaultEndPoint('127.0.0.1', 9042) + self.assertFalse(hasattr(ep, '__dict__')) + with self.assertRaises(AttributeError): + ep.dynamic_attr = 'test' + + # Test SniEndPoint + sni_ep = SniEndPoint('proxy.example.com', 'server.example.com', 9042) + self.assertFalse(hasattr(sni_ep, '__dict__')) + with self.assertRaises(AttributeError): + sni_ep.dynamic_attr = 'test' + + # Test UnixSocketEndPoint + unix_ep = UnixSocketEndPoint('/tmp/cassandra.sock') + self.assertFalse(hasattr(unix_ep, '__dict__')) + with self.assertRaises(AttributeError): + unix_ep.dynamic_attr = 'test' + + def test_frame_class_has_slots(self): + """Test _Frame class has __slots__ implemented.""" + frame = _Frame(4, 0, 1, 7, 9, 100) + self.assertFalse(hasattr(frame, '__dict__')) + with self.assertRaises(AttributeError): + frame.dynamic_attr = 'test' + + # Test that all expected attributes are accessible + self.assertEqual(frame.version, 4) + self.assertEqual(frame.flags, 0) + self.assertEqual(frame.stream, 1) + self.assertEqual(frame.opcode, 7) + self.assertEqual(frame.body_offset, 9) + self.assertEqual(frame.end_pos, 100) + + def test_timer_classes_have_slots(self): + """Test Timer and TimerManager classes have __slots__ implemented.""" + # Test Timer + timer = Timer(5.0, lambda: None) + self.assertFalse(hasattr(timer, '__dict__')) + with self.assertRaises(AttributeError): + timer.dynamic_attr = 'test' + + # Test Timer attributes + self.assertEqual(timer.canceled, False) + self.assertIsNotNone(timer.end) + self.assertIsNotNone(timer.callback) + + # Test TimerManager + timer_mgr = TimerManager() + self.assertFalse(hasattr(timer_mgr, '__dict__')) + with self.assertRaises(AttributeError): + timer_mgr.dynamic_attr = 'test' + + def test_utility_classes_have_slots(self): + """Test utility classes have __slots__ implemented.""" + # Test ShardawarePortGenerator + self.assertEqual(ShardawarePortGenerator.__slots__, ()) + + # Test _ConnectionIOBuffer + class MockConnection: + pass + + io_buffer = _ConnectionIOBuffer(MockConnection()) + self.assertFalse(hasattr(io_buffer, '__dict__')) + with self.assertRaises(AttributeError): + io_buffer.dynamic_attr = 'test' + + # Test ResponseWaiter + response_waiter = ResponseWaiter(MockConnection(), 2, True) + self.assertFalse(hasattr(response_waiter, '__dict__')) + with self.assertRaises(AttributeError): + response_waiter.dynamic_attr = 'test' + + def test_slots_prevent_memory_overhead(self): + """Test that objects with __slots__ don't have __dict__ overhead.""" + instances = [ + DefaultEndPoint('127.0.0.1', 9042), + SniEndPoint('proxy.example.com', 'server.example.com', 9042), + UnixSocketEndPoint('/tmp/cassandra.sock'), + _Frame(4, 0, 1, 7, 9, 100), + Timer(5.0, lambda: None), + TimerManager(), + ] + + for instance in instances: + with self.subTest(instance=instance.__class__.__name__): + # Ensure no __dict__ is present (memory optimization) + self.assertFalse(hasattr(instance, '__dict__')) + # Ensure __slots__ is defined + self.assertTrue(hasattr(instance.__class__, '__slots__')) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file