Skip to content

Commit 40a7c35

Browse files
authored
Merge pull request scylladb#10 from riptano/unified-dse-graph
PYTHON-1113: Unify cassandra-driver and dse-graph repositories
2 parents 0a6025b + 2301f5a commit 40a7c35

29 files changed

+2496
-46
lines changed

build.yaml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -138,7 +138,7 @@ schedules:
138138
schedule: adhoc
139139
disable_pull_requests: true
140140
branches:
141-
include: ['oss-next']
141+
include: [/oss-next.*/]
142142
env_vars: |
143143
EVENT_LOOP_MANAGER='libev'
144144
EXCLUDE_LONG=1

cassandra/cluster.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -2522,7 +2522,8 @@ def execute(self, query, parameters=None, timeout=_NOT_SET, trace=False,
25222522
not yet connected, the query will fail with :class:`NoHostAvailable`. Using this is
25232523
discouraged except in a few cases, e.g., querying node-local tables and applying schema changes.
25242524
2525-
`execute_as` the user that will be used on the server to execute the request.
2525+
`execute_as` the user that will be used on the server to execute the request. This is only available
2526+
on a DSE cluster.
25262527
"""
25272528

25282529
return self.execute_async(query, parameters, trace, custom_payload, timeout, execution_profile, paging_state, host, execute_as).result()
Lines changed: 224 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,224 @@
1+
# Copyright DataStax, Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import logging
16+
import copy
17+
18+
HAVE_GREMLIN = False
19+
try:
20+
import gremlin_python
21+
HAVE_GREMLIN = True
22+
except ImportError:
23+
# gremlinpython is not installed.
24+
pass
25+
26+
if HAVE_GREMLIN:
27+
from gremlin_python.structure.graph import Graph
28+
from gremlin_python.driver.remote_connection import RemoteConnection, RemoteTraversal
29+
from gremlin_python.process.traversal import Traverser, TraversalSideEffects
30+
from gremlin_python.process.graph_traversal import GraphTraversal
31+
32+
from cassandra.cluster import Session, GraphExecutionProfile, EXEC_PROFILE_GRAPH_DEFAULT
33+
from cassandra.datastax.graph import GraphOptions, GraphProtocol
34+
35+
from cassandra.datastax.graph.fluent.serializers import (
36+
GremlinGraphSONReader,
37+
deserializers,
38+
gremlin_deserializers
39+
)
40+
from cassandra.datastax.graph.fluent.query import _DefaultTraversalBatch, _query_from_traversal
41+
42+
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Public names exported by this module.
__all__ = ['BaseGraphRowFactory', 'dse_graphson_reader', 'graphson_reader', 'graph_traversal_row_factory',
           'graph_traversal_dse_object_row_factory', 'DSESessionRemoteGraphConnection', 'DseGraph']

# Create our custom GraphSON readers. Both are GremlinGraphSONReader instances
# and differ only in the deserializer map they are given:
# ``dse_graphson_reader`` uses ``deserializers``, ``graphson_reader`` uses
# ``gremlin_deserializers``.
dse_graphson_reader = GremlinGraphSONReader(deserializer_map=deserializers)
graphson_reader = GremlinGraphSONReader(deserializer_map=gremlin_deserializers)
50+
51+
# Traversal result keys
52+
_bulk_key = 'bulk'
_result_key = 'result'


class BaseGraphRowFactory(object):
    """
    Callable row factory for graph traversal results.

    An instance wraps a graphson reader function and can be used wherever a
    regular row factory is expected. On top of plain decoding it handles the
    following Gremlin/DSE feature:

    - bulk results: a decoded row carrying a ``bulk`` count greater than one
      contributes its ``result`` that many times to the output.

    :param graphson_reader: The function used to read the graphson.

    Use example::

        my_custom_row_factory = BaseGraphRowFactory(custom_graphson_reader.readObject)
    """

    def __init__(self, graphson_reader):
        self._graphson_reader = graphson_reader

    def __call__(self, column_names, rows):
        """
        Decode the first column of every row and return the flattened results.

        :param column_names: Unused; present to match the row factory signature.
        :param rows: Iterable of rows whose first item is the raw graphson payload.
        :return: A list of decoded results, expanded according to each row's bulk count.
        """
        decoded = []

        for raw_row in rows:
            payload = self._graphson_reader(raw_row[0])
            repeat = payload.get(_bulk_key, 1)

            # Only pay for deep copies when the traverser is really bulked.
            # The copies go first; the original object is appended last.
            if repeat > 1:
                for _ in range(repeat - 1):
                    decoded.append(copy.deepcopy(payload[_result_key]))

            decoded.append(payload[_result_key])

        return decoded
89+
90+
# Row factory that decodes rows with ``graphson_reader.readObject``.
graph_traversal_row_factory = BaseGraphRowFactory(graphson_reader.readObject)
# Give this instance its own description instead of the class docstring.
graph_traversal_row_factory.__doc__ = "Row Factory that returns the decoded graphson."

# Row factory that decodes rows with ``dse_graphson_reader.readObject``.
graph_traversal_dse_object_row_factory = BaseGraphRowFactory(dse_graphson_reader.readObject)
# Give this instance its own description instead of the class docstring.
graph_traversal_dse_object_row_factory.__doc__ = "Row Factory that returns the decoded graphson as DSE types."
95+
96+
97+
class DSESessionRemoteGraphConnection(RemoteConnection):
    """
    TinkerPop ``RemoteConnection`` that executes traversal queries through a
    DSE session.

    :param session: A DSE session
    :param graph_name: (Optional) DSE Graph name.
    :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
    """

    session = None
    graph_name = None
    execution_profile = None

    def __init__(self, session, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT):
        # The base connection is initialised with two ``None`` arguments;
        # requests are sent through ``session`` instead.
        super(DSESessionRemoteGraphConnection, self).__init__(None, None)

        if not isinstance(session, Session):
            raise ValueError('A DSE Session must be provided to execute graph traversal queries.')

        self.session, self.graph_name, self.execution_profile = (
            session, graph_name, execution_profile)

    def submit(self, bytecode):
        """
        Execute the traversal ``bytecode`` on the session and wrap the results
        in a ``RemoteTraversal``.

        :param bytecode: The traversal bytecode to execute.
        """
        traversal_query = DseGraph.query_from_traversal(bytecode)

        # Work on a clone of the configured profile so the registered one is
        # left untouched, and decode rows with the traversal row factory.
        profile = self.session.execution_profile_clone_update(
            self.execution_profile,
            row_factory=graph_traversal_row_factory)

        options = profile.graph_options.copy()
        options.graph_language = DseGraph.DSE_GRAPH_QUERY_LANGUAGE
        if self.graph_name:
            options.graph_name = self.graph_name
        profile.graph_options = options

        rows = self.session.execute_graph(traversal_query, execution_profile=profile)
        wrapped = [Traverser(row) for row in rows]
        return RemoteTraversal(iter(wrapped), TraversalSideEffects())

    def __str__(self):
        template = "<DSESessionRemoteGraphConnection: graph_name='{0}'>"
        return template.format(self.graph_name)

    __repr__ = __str__
140+
141+
142+
class DseGraph(object):
    """
    Dse Graph utility class for GraphTraversal construction and execution.
    """

    DSE_GRAPH_QUERY_LANGUAGE = 'bytecode-json'
    """
    Graph query language, Default is 'bytecode-json' (GraphSON).
    """

    @staticmethod
    def query_from_traversal(traversal):
        """
        From a GraphTraversal, return a query string based on the language specified in `DseGraph.DSE_GRAPH_QUERY_LANGUAGE`.

        A warning is logged when the traversal is bound to a
        :class:`DSESessionRemoteGraphConnection` that carries a session, graph
        name or execution profile, since those settings are not part of the
        generated query.

        :param traversal: The GraphTraversal object
        :return: The result of ``_query_from_traversal(traversal)``.
        """

        if isinstance(traversal, GraphTraversal):
            for strategy in traversal.traversal_strategies.traversal_strategies:
                # Not every strategy object is guaranteed to expose a remote
                # connection, so look it up defensively.
                rc = getattr(strategy, 'remote_connection', None)
                # The parentheses matter: without them ``and`` binds tighter
                # than ``or`` and the attribute reads would run on objects
                # that are not DSESessionRemoteGraphConnection instances.
                if (isinstance(rc, DSESessionRemoteGraphConnection) and
                        (rc.session or rc.graph_name or rc.execution_profile)):
                    log.warning("GraphTraversal session, graph_name and execution_profile are "
                                "only taken into account when executed with TinkerPop.")

        return _query_from_traversal(traversal)

    @staticmethod
    def traversal_source(session=None, graph_name=None, execution_profile=EXEC_PROFILE_GRAPH_DEFAULT,
                         traversal_class=None):
        """
        Returns a TinkerPop GraphTraversalSource bound to the session and graph_name if provided.

        :param session: (Optional) A DSE session
        :param graph_name: (Optional) DSE Graph name
        :param execution_profile: (Optional) Execution profile for traversal queries. Default is set to `EXEC_PROFILE_GRAPH_DEFAULT`.
        :param traversal_class: (Optional) The GraphTraversalSource class to use (DSL).
        :return: The traversal source; it is attached to a
            :class:`DSESessionRemoteGraphConnection` only when ``session`` is given.

        .. code-block:: python

            from cassandra.cluster import Cluster
            from cassandra.datastax.graph.fluent import DseGraph

            c = Cluster()
            session = c.connect()

            g = DseGraph.traversal_source(session, 'my_graph')
            print(g.V().valueMap().toList())

        """

        graph = Graph()
        traversal_source = graph.traversal(traversal_class)

        if session:
            traversal_source = traversal_source.withRemote(
                DSESessionRemoteGraphConnection(session, graph_name, execution_profile))

        return traversal_source

    @staticmethod
    def create_execution_profile(graph_name):
        """
        Creates an ExecutionProfile for GraphTraversal execution. You need to register that execution profile to the
        cluster by using `cluster.add_execution_profile`.

        :param graph_name: The graph name
        :return: A ``GraphExecutionProfile`` using the DSE object row factory,
            the `DseGraph.DSE_GRAPH_QUERY_LANGUAGE` language and the GraphSON 2.0 protocol.
        """

        ep = GraphExecutionProfile(row_factory=graph_traversal_dse_object_row_factory,
                                   graph_options=GraphOptions(graph_name=graph_name,
                                                              graph_language=DseGraph.DSE_GRAPH_QUERY_LANGUAGE,
                                                              graph_protocol=GraphProtocol.GRAPHSON_2_0))
        return ep

    @staticmethod
    def batch(*args, **kwargs):
        """
        Returns the :class:`cassandra.datastax.graph.fluent.query.TraversalBatch` object allowing to
        execute multiple traversals in the same transaction.

        All positional and keyword arguments are forwarded unchanged to the
        batch constructor.
        """
        return _DefaultTraversalBatch(*args, **kwargs)

0 commit comments

Comments
 (0)