Skip to content

Add method to asynchronously prepare CQL statements #1239

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
128 changes: 98 additions & 30 deletions cassandra/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2717,7 +2717,7 @@ def execute_async(self, query, parameters=None, trace=False, custom_payload=None
if execute_as:
custom_payload[_proxy_execute_key] = execute_as.encode()

future = self._create_response_future(
future = self._create_execute_response_future(
query, parameters, trace, custom_payload, timeout,
execution_profile, paging_state, host)
future._protocol_handler = self.client_protocol_handler
Expand Down Expand Up @@ -2782,8 +2782,8 @@ def execute_graph_async(self, query, parameters=None, trace=False, execution_pro
custom_payload[_proxy_execute_key] = execute_as.encode()
custom_payload[_request_timeout_key] = int64_pack(int(execution_profile.request_timeout * 1000))

future = self._create_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
timeout=_NOT_SET, execution_profile=execution_profile)
future = self._create_execute_response_future(query, parameters=None, trace=trace, custom_payload=custom_payload,
timeout=_NOT_SET, execution_profile=execution_profile)

future.message.query_params = graph_parameters
future._protocol_handler = self.client_protocol_handler
Expand Down Expand Up @@ -2885,9 +2885,9 @@ def _transform_params(self, parameters, graph_options):

def _target_analytics_master(self, future):
future._start_timer()
master_query_future = self._create_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
parameters=None, trace=False,
custom_payload=None, timeout=future.timeout)
master_query_future = self._create_execute_response_future("CALL DseClientTool.getAnalyticsGraphServer()",
parameters=None, trace=False,
custom_payload=None, timeout=future.timeout)
master_query_future.row_factory = tuple_factory
master_query_future.send_request()

Expand All @@ -2910,9 +2910,43 @@ def _on_analytics_master_result(self, response, master_future, query_future):

self.submit(query_future.send_request)

def _create_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
def prepare_async(self, query, custom_payload=None, keyspace=None, prepare_on_all_hosts=None):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function should live near the impl for prepare(). So either it should be moved down below prepare() or we should bring prepare() up here.

"""
Prepare the given query and return a :class:`~.PrepareFuture`
object. You may also call :meth:`~.PrepareFuture.result()`
on the :class:`.PrepareFuture` to synchronously block for
prepared statement object at any time.

See :meth:`Session.prepare` for parameter definitions.

Example usage::

>>> future = session.prepare_async("SELECT * FROM mycf")
>>> # do other stuff...

>>> try:
... prepared_statement = future.result()
... except Exception:
... log.exception("Operation failed:")

When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method
attempts to prepare given query on all hosts, but does not wait
for their response.
"""
if prepare_on_all_hosts is None:
prepare_on_all_hosts = self.cluster.prepare_on_all_hosts
future = self._create_prepare_response_future(query, keyspace, custom_payload, prepare_on_all_hosts)
future._protocol_handler = self.client_protocol_handler
self._on_request(future)
future.send_request()
return future

def _create_prepare_response_future(self, query, keyspace, custom_payload, prepare_on_all_hosts):
return PrepareFuture(self, query, keyspace, custom_payload, self.default_timeout, prepare_on_all_hosts)

def _create_execute_response_future(self, query, parameters, trace, custom_payload,
timeout, execution_profile=EXEC_PROFILE_DEFAULT,
paging_state=None, host=None):
""" Returns the ResponseFuture before calling send_request() on it """

prepared_statement = None
Expand Down Expand Up @@ -3118,36 +3152,27 @@ def prepare(self, query, custom_payload=None, keyspace=None):
**Important**: PreparedStatements should be prepared only once.
Preparing the same query more than once will likely affect performance.

When :meth:`~.Cluster.prepare_on_all_hosts` is enabled, method
attempts to prepare given query on all hosts and waits for each node to respond.
Preparing CQL query on other nodes may fail, but error is not propagated
to the caller.

`custom_payload` is a key value map to be passed along with the prepare
message. See :ref:`custom_payload`.
"""
message = PrepareMessage(query=query, keyspace=keyspace)
future = ResponseFuture(self, message, query=None, timeout=self.default_timeout)
try:
future.send_request()
response = future.result().one()
except Exception:
log.exception("Error preparing query:")
raise

prepared_keyspace = keyspace if keyspace else None
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, self.cluster.metadata, query, prepared_keyspace,
self._protocol_version, response.column_metadata, response.result_metadata_id, self.cluster.column_encryption_policy)
prepared_statement.custom_payload = future.custom_payload

self.cluster.add_prepared(response.query_id, prepared_statement)

future = self.prepare_async(query, custom_payload, keyspace, prepare_on_all_hosts=False)
response = future.result()
if self.cluster.prepare_on_all_hosts:
# prepare on all hosts in a synchronous way, not asynchronously
# as internally in prepare_async() (PrepareFuture)
host = future._current_host
try:
self.prepare_on_all_hosts(prepared_statement.query_string, host, prepared_keyspace)
self.prepare_on_all_nodes(response.query_string, host, response.keyspace)
except Exception:
log.exception("Error preparing query on all hosts:")
return response
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems exactly right: prepare() should be implemented as prepare_async().get() once we have a good working prepare_async(). But why do we need to do the prepare_on_all_nodes() operation here? We should've already done that when future.result() completed (since it's done in PrepareFuture._set_final_result()).


return prepared_statement

def prepare_on_all_hosts(self, query, excluded_host, keyspace=None):
def prepare_on_all_nodes(self, query, excluded_host, keyspace=None):
"""
Prepare the given query on all hosts, excluding ``excluded_host``.
Intended for internal use only.
Expand Down Expand Up @@ -5105,6 +5130,49 @@ def __str__(self):
__repr__ = __str__


class PrepareFuture(ResponseFuture):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm extremely skeptical of the idea of extending ResponseFuture to a prepare-specific future implementation. There's a lot of functionality in ResponseFuture and we'd have to make sure everything we need there was duplicated here... and it's easy to miss things. Specifically ResponseFuture already has a lot of logic for dealing with prepare statements + responses... I'd rather find a way to re-use that and handle the prepare-on-all-hosts ops via callbacks (or perhaps something better) rather than subclass the future impl.

_final_prepare_result = _NOT_SET

def __init__(self, session, query, keyspace, custom_payload, timeout, prepare_on_all_hosts):
super().__init__(session, PrepareMessage(query=query, keyspace=keyspace), None, timeout)
self.query_string = query
self._prepare_on_all_hosts = prepare_on_all_hosts
self._keyspace = keyspace
self._custom_payload = custom_payload

def _set_final_result(self, response):
session = self.session
cluster = session.cluster
prepared_statement = PreparedStatement.from_message(
response.query_id, response.bind_metadata, response.pk_indexes, cluster.metadata, self.query_string,
self._keyspace, session._protocol_version, response.column_metadata, response.result_metadata_id,
cluster.column_encryption_policy)
prepared_statement.custom_payload = response.custom_payload
cluster.add_prepared(response.query_id, prepared_statement)
self._final_prepare_result = prepared_statement

if self._prepare_on_all_hosts:
# trigger asynchronous preparation of query on other C* nodes,
# we are on event loop thread, so do not execute those synchronously
session.submit(
session.prepare_on_all_nodes,
self.query_string, self._current_host, self._keyspace)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems strange to me to have this logic embedded within the future impl like this. Seems like this should be done when the future is created, something like:

# _create_prepare_response_future() in this impl returns a regular ResponseFuture
future = self._create_prepare_response_future(query, keyspace, custom_payload, prepare_on_all_hosts)
if prepare_on_all_hosts:
   # Hand waving about partial application here in order to pass in parameters; the point is we get to a future that
   # calls prepare_on_all_hosts() after we've received a response to our prepare here here
   future = future.then(prepare_on_all_hosts)
future._protocol_handler = self.client_protocol_handler

Problem of course is the ResponseFuture doesn't support this then() syntax. It does have native support for callbacks but that isn't the same; that's just a function that gets invoked when the operation completes. We don't return a new future that returns the result of the function defined in the then() call like you do in most future APIs.

I'm wondering if we can either (a) add something like that or (b) find a way to wrap this functionality in another future lib in order to simplify an impl like this.


super()._set_final_result(response)

def result(self):
self._event.wait()
if self._final_prepare_result is not _NOT_SET:
return self._final_prepare_result
else:
raise self._final_exception

def __str__(self):
result = "(no result yet)" if self._final_result is _NOT_SET else self._final_result
return "<PrepareFuture: query='%s' request_id=%s result=%s exception=%s coordinator_host=%s>" \
% (self.query_string, self._req_id, result, self._final_exception, self.coordinator_host)
__repr__ = __str__

class QueryExhausted(Exception):
"""
Raised when :meth:`.ResponseFuture.start_fetching_next_page()` is called and
Expand Down
2 changes: 1 addition & 1 deletion cassandra/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -1035,7 +1035,7 @@ def populate(self, max_wait=2.0, wait_for_complete=True, query_cl=None):

def _execute(self, query, parameters, time_spent, max_wait):
timeout = (max_wait - time_spent) if max_wait is not None else None
future = self._session._create_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
future = self._session._create_execute_response_future(query, parameters, trace=False, custom_payload=None, timeout=timeout)
# in case the user switched the row factory, set it to namedtuple for this query
future.row_factory = named_tuple_factory
future.send_request()
Expand Down
78 changes: 78 additions & 0 deletions tests/integration/standard/test_prepared_statements.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from cassandra import InvalidRequest, DriverException

from cassandra import ConsistencyLevel, ProtocolVersion
from cassandra.cluster import PrepareFuture
from cassandra.query import PreparedStatement, UNSET_VALUE
from tests.integration import (get_server_versions, greaterthanorequalcass40, greaterthanorequaldse50,
requirecassandra, BasicSharedKeyspaceUnitTestCase)
Expand Down Expand Up @@ -121,6 +122,83 @@ def test_basic(self):
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_basic_async(self):
"""
Test basic asynchronous PreparedStatement usage
"""
self.session.execute(
"""
DROP KEYSPACE IF EXISTS preparedtests
"""
)
self.session.execute(
"""
CREATE KEYSPACE preparedtests
WITH replication = {'class': 'SimpleStrategy', 'replication_factor': '1'}
""")

self.session.set_keyspace("preparedtests")
self.session.execute(
"""
CREATE TABLE cf0 (
a text,
b text,
c text,
PRIMARY KEY (a, b)
)
""")

prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a', 'b', 'c'))
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind(('a'))
results = self.session.execute(bound)
self.assertEqual(results, [('a', 'b', 'c')])

# test with new dict binding
prepared_future = self.session.prepare_async(
"""
INSERT INTO cf0 (a, b, c) VALUES (?, ?, ?)
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({
'a': 'x',
'b': 'y',
'c': 'z'
})
self.session.execute(bound)

prepared_future = self.session.prepare_async(
"""
SELECT * FROM cf0 WHERE a=?
""")
self.assertIsInstance(prepared_future, PrepareFuture)
prepared = prepared_future.result()
self.assertIsInstance(prepared, PreparedStatement)

bound = prepared.bind({'a': 'x'})
results = self.session.execute(bound)
self.assertEqual(results, [('x', 'y', 'z')])

def test_missing_primary_key(self):
"""
Ensure an InvalidRequest is thrown
Expand Down
28 changes: 26 additions & 2 deletions tests/integration/standard/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,6 +528,20 @@ def test_prepare_on_all_hosts(self):
session.execute(select_statement, (1, ), host=host)
self.assertEqual(2, self.mock_handler.get_message_count('debug', "Re-preparing"))

def test_prepare_async_on_all_hosts(self):
"""
Test to validate prepare_on_all_hosts flag is honored during prepare_async execution.
"""
clus = TestCluster(prepare_on_all_hosts=True)
self.addCleanup(clus.shutdown)

session = clus.connect(wait_for_all_pools=True)
select_statement = session.prepare_async("SELECT k FROM test3rf.test WHERE k = ?").result()
time.sleep(1) # we have no way to know when prepared statements are asynchronously completed
for host in clus.metadata.all_hosts():
session.execute(select_statement, (1, ), host=host)
self.assertEqual(0, self.mock_handler.get_message_count('debug', "Re-preparing"))

def test_prepare_batch_statement(self):
"""
Test to validate a prepared statement used inside a batch statement is correctly handled
Expand Down Expand Up @@ -647,7 +661,6 @@ def test_prepared_statement(self):

prepared = session.prepare('INSERT INTO test3rf.test (k, v) VALUES (?, ?)')
prepared.consistency_level = ConsistencyLevel.ONE

self.assertEqual(str(prepared),
'<PreparedStatement query="INSERT INTO test3rf.test (k, v) VALUES (?, ?)", consistency=ONE>')

Expand Down Expand Up @@ -717,6 +730,17 @@ def test_prepared_statements(self):
self.session.execute_async(batch).result()
self.confirm_results()

def test_prepare_async(self):
prepared = self.session.prepare_async("INSERT INTO test3rf.test (k, v) VALUES (?, ?)").result()

batch = BatchStatement(BatchType.LOGGED)
for i in range(10):
batch.add(prepared, (i, i))

self.session.execute(batch)
self.session.execute_async(batch).result()
self.confirm_results()

def test_bound_statements(self):
prepared = self.session.prepare("INSERT INTO test3rf.test (k, v) VALUES (?, ?)")

Expand Down Expand Up @@ -942,7 +966,7 @@ def test_no_connection_refused_on_timeout(self):
exception_type = type(result).__name__
if exception_type == "NoHostAvailable":
self.fail("PYTHON-91: Disconnected from Cassandra: %s" % result.message)
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub"]:
if exception_type in ["WriteTimeout", "WriteFailure", "ReadTimeout", "ReadFailure", "ErrorMessageSub", "ErrorMessage"]:
if type(result).__name__ in ["WriteTimeout", "WriteFailure"]:
received_timeout = True
continue
Expand Down