Skip to content

Commit d546cac

Browse files
committed
1
1 parent c1809c1 commit d546cac

File tree

3 files changed

+70
-15
lines changed

3 files changed

+70
-15
lines changed

cassandra/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3111,7 +3111,7 @@ def _create_response_future(self, query, parameters, trace, custom_payload,
31113111
message = ExecuteMessage(
31123112
prepared_statement.query_id, query.values, cl,
31133113
serial_cl, fetch_size, paging_state, timestamp,
3114-
skip_meta=bool(prepared_statement.result_metadata),
3114+
can_have_result_metadata=bool(prepared_statement.result_metadata),
31153115
continuous_paging_options=continuous_paging_options,
31163116
result_metadata_id=prepared_statement.result_metadata_id)
31173117
elif isinstance(query, BatchStatement):

cassandra/protocol.py

Lines changed: 47 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from collections import namedtuple
1717
import logging
1818
import socket
19+
from typing import Optional
1920
from uuid import UUID
2021

2122
import io
@@ -42,6 +43,7 @@
4243
from cassandra import WriteType
4344
from cassandra.cython_deps import HAVE_CYTHON, HAVE_NUMPY
4445
from cassandra import util
46+
from cassandra.protocol_features import ProtocolFeatures
4547

4648
log = logging.getLogger(__name__)
4749

@@ -541,11 +543,35 @@ def recv_body(cls, f, *args):
541543

542544

543545
class _QueryMessage(_MessageType):
544-
546+
"""
547+
Represents a query frame sent to the database.
548+
549+
This message encapsulates all parameters required for executing a query,
550+
including consistency settings, paging controls, and metadata options.
551+
552+
Attributes:
553+
query_params: Encoded query parameters or a prepared statement ID with bound values.
554+
consistency_level: The desired consistency level for the query.
555+
serial_consistency_level: Optional serial consistency level for conditional updates (e.g., LOCAL_SERIAL).
556+
fetch_size: Optional number of rows to fetch per page.
557+
paging_state: Optional opaque paging state token for continuing from a previous query.
558+
timestamp: Optional client-supplied timestamp for the query.
559+
skip_meta: Optional flag indicating if result metadata should be skipped in the response.
560+
continuous_paging_options: Optional configuration for continuous paging behavior.
561+
keyspace: Optional keyspace to associate with the query.
562+
can_have_result_metadata: Optional flag indicating if the query is expected to have result metadata.
563+
564+
When skip_meta is not set to True or False it resolves it the safest way possible.
565+
Every protocol before 5 has result metadata invalidation loopholes, when metadata on server
566+
and client stays different without client noticing it.
567+
To solve this problem for protocol 4 scylla introduced SCYLLA_USE_METADATA_ID feature
568+
that allows client spot when result metadata on client divert from server.
569+
Read https://github.com/scylladb/scylla-drivers/issues/81 for more details
570+
"""
545571
def __init__(self, query_params, consistency_level,
546572
serial_consistency_level=None, fetch_size=None,
547-
paging_state=None, timestamp=None, skip_meta=False,
548-
continuous_paging_options=None, keyspace=None):
573+
paging_state=None, timestamp=None, skip_meta=None,
574+
continuous_paging_options=None, keyspace=None, can_have_result_metadata: bool = False):
549575
self.query_params = query_params
550576
self.consistency_level = consistency_level
551577
self.serial_consistency_level = serial_consistency_level
@@ -555,8 +581,9 @@ def __init__(self, query_params, consistency_level,
555581
self.skip_meta = skip_meta
556582
self.continuous_paging_options = continuous_paging_options
557583
self.keyspace = keyspace
584+
self.can_have_result_metadata = can_have_result_metadata
558585

559-
def _write_query_params(self, f, protocol_version):
586+
def _write_query_params(self, f, protocol_version, protocol_features: ProtocolFeatures):
560587
write_consistency_level(f, self.consistency_level)
561588
flags = 0x00
562589
if self.query_params is not None:
@@ -606,7 +633,12 @@ def _write_query_params(self, f, protocol_version):
606633
"Keyspaces may only be set on queries with protocol version "
607634
"5 or DSE_V2 or higher. Consider setting Cluster.protocol_version.")
608635

609-
if self.skip_meta is not None and self.skip_meta:
636+
if self.skip_meta:
637+
flags |= _SKIP_METADATA_FLAG
638+
elif self.skip_meta is None and self.can_have_result_metadata and (protocol_version >= 5 or protocol_features.use_metadata_id):
639+
# Skip metadata only when protocol allows to invalidate result metadata properly
640+
# i.e. protocol version >= 5 or SCYLLA_USE_METADATA_ID present
641+
# Read https://github.com/scylladb/scylla-drivers/issues/81 for more details
610642
flags |= _SKIP_METADATA_FLAG
611643

612644
if ProtocolVersion.uses_int_query_flags(protocol_version):
@@ -648,9 +680,9 @@ def __init__(self, query, consistency_level, serial_consistency_level=None,
648680
super(QueryMessage, self).__init__(None, consistency_level, serial_consistency_level, fetch_size,
649681
paging_state, timestamp, False, continuous_paging_options, keyspace)
650682

651-
def send_body(self, f, protocol_version, protocol_features):
683+
def send_body(self, f, protocol_version, protocol_features: ProtocolFeatures):
652684
write_longstring(f, self.query)
653-
self._write_query_params(f, protocol_version)
685+
self._write_query_params(f, protocol_version, protocol_features)
654686

655687

656688
class ExecuteMessage(_QueryMessage):
@@ -659,14 +691,14 @@ class ExecuteMessage(_QueryMessage):
659691

660692
def __init__(self, query_id, query_params, consistency_level,
661693
serial_consistency_level=None, fetch_size=None,
662-
paging_state=None, timestamp=None, skip_meta=False,
663-
continuous_paging_options=None, result_metadata_id=None):
694+
paging_state=None, timestamp=None, skip_meta=None,
695+
continuous_paging_options=None, result_metadata_id=None, can_have_result_metadata: bool = False):
664696
self.query_id = query_id
665697
self.result_metadata_id = result_metadata_id
666698
super(ExecuteMessage, self).__init__(query_params, consistency_level, serial_consistency_level, fetch_size,
667-
paging_state, timestamp, skip_meta, continuous_paging_options)
699+
paging_state, timestamp, skip_meta, continuous_paging_options, can_have_result_metadata=can_have_result_metadata)
668700

669-
def _write_query_params(self, f, protocol_version):
701+
def _write_query_params(self, f, protocol_version, protocol_features: ProtocolFeatures):
670702
if protocol_version == 1:
671703
if self.serial_consistency_level:
672704
raise UnsupportedOperation(
@@ -682,13 +714,13 @@ def _write_query_params(self, f, protocol_version):
682714
write_value(f, param)
683715
write_consistency_level(f, self.consistency_level)
684716
else:
685-
super(ExecuteMessage, self)._write_query_params(f, protocol_version)
717+
super(ExecuteMessage, self)._write_query_params(f, protocol_version, protocol_features)
686718

687719
def send_body(self, f, protocol_version, protocol_features):
688720
write_string(f, self.query_id)
689721
if ProtocolVersion.uses_prepared_metadata(protocol_version) or protocol_features.use_metadata_id:
690-
write_string(f, self.result_metadata_id)
691-
self._write_query_params(f, protocol_version)
722+
write_string(f, self.result_metadata_id or "")
723+
self._write_query_params(f, protocol_version, protocol_features)
692724

693725

694726
CUSTOM_TYPE = object()
@@ -733,6 +765,7 @@ class ResultMessage(_MessageType):
733765
bind_metadata = None
734766
pk_indexes = None
735767
schema_change_event = None
768+
result_metadata_id = None
736769

737770
def __init__(self, kind):
738771
self.kind = kind

tests/unit/test_protocol.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
from unittest.mock import Mock
1818

19+
import pytest
20+
1921
from cassandra import ProtocolVersion, UnsupportedOperation
2022
from cassandra.protocol import (
2123
PrepareMessage, QueryMessage, ExecuteMessage, UnsupportedOperation,
@@ -68,6 +70,26 @@ def test_execute_message(self):
6870
(b'\x00\x04',),
6971
(b'\x00\x00\x00\x01',), (b'\x00\x00',)])
7072

73+
def test_execute_metadata_id(self):
74+
for protocol, use_metadata_id, can_have_result_metadata in [
75+
(4, True, True), (4, True, False), (4, False, True), (4, False, False), (5, True, True), (5, True, False),
76+
(5, False, True), (5, False, False)]:
77+
print(protocol, use_metadata_id, can_have_result_metadata)
78+
message = ExecuteMessage('1', [], 4, can_have_result_metadata=can_have_result_metadata)
79+
io = Mock()
80+
81+
message.send_body(io, protocol, ProtocolFeatures(use_metadata_id=use_metadata_id))
82+
self._check_calls(io, [(b'\x00\x01',), (b'1',), (b'\x00\x04',), (b'\x01',), (b'\x00\x00',)])
83+
84+
io.reset_mock()
85+
message.result_metadata_id = 'foo'
86+
message.send_body(io, 5, ProtocolFeatures())
87+
88+
self._check_calls(io, [(b'\x00\x01',), (b'1',),
89+
(b'\x00\x03',), (b'foo',),
90+
(b'\x00\x04',),
91+
(b'\x00\x00\x00\x01',), (b'\x00\x00',)])
92+
7193
def test_query_message(self):
7294
"""
7395
Test to check the appropriate calls are made

0 commit comments

Comments
 (0)