Skip to content

Commit 2db3728

Browse files
authored
PYTHON-1352 Add vector type, codec + support for parsing CQL type (#1161)
1 parent 643d3a6 commit 2db3728

File tree

3 files changed

+56
-5
lines changed

3 files changed

+56
-5
lines changed

cassandra/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def emit(self, record):
2222

2323
logging.getLogger('cassandra').addHandler(NullHandler())
2424

25-
__version_info__ = (3, 27, 0)
25+
# Release 3.28.0, first beta.  The pre-release tag is kept as a separate
# string element: writing it as the bare literal ``0b1`` is parsed by Python
# as the *binary integer* 1, which made the version read "3.28.1".
__version_info__ = (3, 28, 0, 'b1')
# Dotted numeric release segment, with any pre-release tag appended directly
# (no separator), e.g. "3.28.0b1".
__version__ = '.'.join(map(str, __version_info__[:3])) + ''.join(map(str, __version_info__[3:]))
2727

2828

cassandra/cqltypes.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,13 +235,15 @@ def parse_casstype_args(typestring):
235235
else:
236236
names.append(None)
237237

238-
ctype = lookup_casstype_simple(tok)
238+
try:
239+
ctype = int(tok)
240+
except ValueError:
241+
ctype = lookup_casstype_simple(tok)
239242
types.append(ctype)
240243

241244
# return the first (outer) type, which will have all parameters applied
242245
return args[0][0][0]
243246

244-
245247
def lookup_casstype(casstype):
246248
"""
247249
Given a Cassandra type as a string (possibly including parameters), hand
@@ -259,6 +261,7 @@ def lookup_casstype(casstype):
259261
try:
260262
return parse_casstype_args(casstype)
261263
except (ValueError, AssertionError, IndexError) as e:
264+
log.debug("Exception in parse_casstype_args: %s" % e)
262265
raise ValueError("Don't know how to parse type string %r: %s" % (casstype, e))
263266

264267

@@ -296,7 +299,7 @@ class _CassandraType(object):
296299
"""
297300

298301
def __repr__(self):
299-
return '<%s( %r )>' % (self.cql_parameterized_type(), self.val)
302+
return '<%s>' % (self.cql_parameterized_type())
300303

301304
@classmethod
302305
def from_binary(cls, byts, protocol_version):
@@ -1423,3 +1426,31 @@ def serialize(cls, v, protocol_version):
14231426
buf.write(int8_pack(cls._encode_precision(bound.precision)))
14241427

14251428
return buf.getvalue()
1429+
1430+
class VectorType(_CassandraType):
    """
    Fixed-length vector whose elements all share a single subtype.

    The bare class acts as a template: ``apply_parameters`` creates a
    subclass that carries the element type in ``subtype`` and the number of
    elements in ``vector_size``.
    """
    typename = 'org.apache.cassandra.db.marshal.VectorType'
    # Number of elements in the vector; set on parameterized subclasses.
    vector_size = 0
    # Element type; set on parameterized subclasses.
    subtype = None

    @classmethod
    def apply_parameters(cls, params, names):
        """
        Build a parameterized subclass from ``(subtype, size)``.

        :param params: two-item sequence; the element type (resolved through
            ``lookup_casstype``) followed by the vector size.
        :param names: not used by this type.
        :raises ValueError: if ``params`` does not hold exactly two items.
        """
        # Explicit raise instead of ``assert`` so the check is not stripped
        # when Python runs with -O.
        if len(params) != 2:
            raise ValueError(
                "VectorType expects exactly 2 parameters (subtype, size), got %d" % len(params))
        subtype = lookup_casstype(params[0])
        vsize = params[1]
        subclass_name = '%s(%s)' % (cls.cass_parameterized_type_with([]), vsize)
        return type(subclass_name, (cls,), {'vector_size': vsize, 'subtype': subtype})

    @classmethod
    def deserialize(cls, byts, protocol_version):
        """
        Decode ``byts`` into a list of ``vector_size`` elements.

        The buffer is split into ``vector_size`` equally sized slices and
        each slice is handed to ``subtype.deserialize``.  The slice width is
        derived from the buffer length rather than fixed at 4 bytes, so
        element types of other widths decode correctly.

        :raises ValueError: if the buffer is empty or its length is not a
            multiple of ``vector_size``.
        """
        count = cls.vector_size
        if count <= 0:
            return []
        elem_size, leftover = divmod(len(byts), count)
        if leftover or not elem_size:
            raise ValueError(
                "Cannot split %d bytes into %d equally sized vector elements" % (len(byts), count))
        decode = cls.subtype.deserialize
        return [decode(byts[start:start + elem_size], protocol_version)
                for start in range(0, count * elem_size, elem_size)]

    @classmethod
    def serialize(cls, v, protocol_version):
        """
        Encode the iterable ``v`` by concatenating each element's
        ``subtype.serialize`` output, in order.

        :raises ValueError: if ``v`` does not yield exactly ``vector_size``
            elements.
        """
        encode = cls.subtype.serialize if cls.subtype is not None else None
        chunks = [encode(item, protocol_version) for item in v]
        if len(chunks) != cls.vector_size:
            raise ValueError(
                "Expected a vector of %s elements but received %d" % (cls.vector_size, len(chunks)))
        return b''.join(chunks)

    @classmethod
    def cql_parameterized_type(cls):
        """Return the type rendered as ``<typename><<subtype typename>, <size>>``."""
        return "%s<%s, %s>" % (cls.typename, cls.subtype.typename, cls.vector_size)

tests/unit/test_types.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@
2727
EmptyValue, LongType, SetType, UTF8Type,
2828
cql_typename, int8_pack, int64_pack, lookup_casstype,
2929
lookup_casstype_simple, parse_casstype_args,
30-
int32_pack, Int32Type, ListType, MapType
30+
int32_pack, Int32Type, ListType, MapType, VectorType,
31+
FloatType
3132
)
3233
from cassandra.encoder import cql_quote
3334
from cassandra.pool import Host
@@ -190,6 +191,12 @@ class BarType(FooType):
190191
self.assertEqual(UTF8Type, ctype.subtypes[2])
191192
self.assertEqual([b'city', None, b'zip'], ctype.names)
192193

194+
def test_parse_casstype_vector(self):
    """Parsing a VectorType string yields a VectorType subclass with its subtype and size set."""
    type_string = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 3)"
    vector_cls = parse_casstype_args(type_string)

    self.assertTrue(issubclass(vector_cls, VectorType))
    self.assertEqual(3, vector_cls.vector_size)
    self.assertEqual(FloatType, vector_cls.subtype)
199+
193200
def test_empty_value(self):
194201
self.assertEqual(str(EmptyValue()), 'EMPTY')
195202

@@ -303,6 +310,19 @@ def test_cql_quote(self):
303310
self.assertEqual(cql_quote('test'), "'test'")
304311
self.assertEqual(cql_quote(0), '0')
305312

313+
def test_vector_round_trip(self):
    """Serializing then deserializing a 4-float vector returns the same values (to 5 places)."""
    expected = [3.4, 2.9, 41.6, 12.0]
    vector_cls = parse_casstype_args(
        "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)")

    encoded = vector_cls.serialize(expected, 0)
    # Four elements serialized back to back are expected to occupy 16 bytes.
    self.assertEqual(16, len(encoded))

    decoded = vector_cls.deserialize(encoded, 0)
    self.assertEqual(len(expected), len(decoded))
    for want, got in zip(expected, decoded):
        self.assertAlmostEqual(want, got, places=5)
322+
323+
def test_vector_cql_parameterized_type(self):
    """A parsed vector type renders its parameterized name with the subtype's typename and the size."""
    type_string = "org.apache.cassandra.db.marshal.VectorType(org.apache.cassandra.db.marshal.FloatType, 4)"
    vector_cls = parse_casstype_args(type_string)
    rendered = vector_cls.cql_parameterized_type()
    self.assertEqual(rendered, "org.apache.cassandra.db.marshal.VectorType<float, 4>")
306326

307327
ZERO = datetime.timedelta(0)
308328

0 commit comments

Comments
 (0)