Skip to content

Commit b9872c5

Browse files
Ansgar Lampe authored and madscientist committed
[DB-40533] Add support for new type VECTOR(<dim>,DOUBLE)
Add encode/decode support for the new VECTOR datatype. A result column of type VECTOR(<dim>, DOUBLE) is reported as datatype.VECTOR_DOUBLE in the metadata description. Its values are returned as datatype.Vector with subtype datatype.Vector.DOUBLE, which can be used like a list. For a parameter to be detected as VECTOR(<dim>, DOUBLE), it must also be passed as a datatype.Vector. Alternatively, a string can be bound; this currently includes binding a plain list of numbers, which is implicitly converted to a string. This commit also adds tests for the described behavior. The tests require NuoDB 8.0 or later to run.
1 parent 3b1aaba commit b9872c5

File tree

6 files changed

+204
-6
lines changed

6 files changed

+204
-6
lines changed

pynuodb/datatype.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,15 +27,17 @@
2727

2828
__all__ = ['Date', 'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks',
2929
'TimestampFromTicks', 'DateToTicks', 'TimeToTicks',
30-
'TimestampToTicks', 'Binary', 'STRING', 'BINARY', 'NUMBER',
31-
'DATETIME', 'ROWID', 'TypeObjectFromNuodb']
30+
'TimestampToTicks', 'Binary', 'Vector', 'STRING', 'BINARY', 'NUMBER',
31+
'DATETIME', 'ROWID', 'VECTOR_DOUBLE', 'TypeObjectFromNuodb']
3232

3333
import sys
3434
import decimal
3535
from datetime import datetime as Timestamp, date as Date, time as Time
3636
from datetime import timedelta as TimeDelta
3737
from datetime import tzinfo # pylint: disable=unused-import
3838

39+
from pynuodb import protocol
40+
3941
try:
4042
from typing import Tuple, Union # pylint: disable=unused-import
4143
except ImportError:
@@ -279,10 +281,37 @@ def __cmp__(self, other):
279281
return -1
280282

281283

284+
class Vector(list):
    """A specific type for SQL VECTOR(<dim>, <subtype>).

    Wrapping the data in this class (instead of passing a plain list) lets
    the driver detect the desired type when binding parameters.  Apart from
    being created with a subtype as the first argument, an instance can be
    used like a list.

    Only the DOUBLE subtype is currently supported.
    """
    # Subtype marker for VECTOR(<dim>, DOUBLE)
    DOUBLE = protocol.VECTOR_DOUBLE

    def __init__(self, subtype, *args, **kwargs):
        """Create a vector of the given subtype.

        :param subtype: the vector subtype; must be Vector.DOUBLE
        :param args: forwarded to the list constructor (typically one
            iterable holding the elements); may be omitted to create an
            empty vector
        :param kwargs: forwarded to the list constructor
        :raises TypeError: if the subtype is not Vector.DOUBLE
        """
        if subtype != Vector.DOUBLE:
            if not args:
                # A single argument that is not a valid subtype: most likely
                # the data was passed without a leading subtype.
                raise TypeError("Vector needs to be initialized with a subtype like Vector.DOUBLE as"
                                " first argument")
            raise TypeError("Vector type only supported for subtype DOUBLE")

        self.subtype = subtype

        # forward the remaining arguments to the list __init__
        super(Vector, self).__init__(*args, **kwargs)

    def getSubtype(self):
        # type: () -> int
        """Returns the subtype of vector this instance holds data for"""
        return self.subtype
308+
309+
282310
STRING = TypeObject(str)
283311
BINARY = TypeObject(str)
284312
NUMBER = TypeObject(int, decimal.Decimal)
285313
DATETIME = TypeObject(Timestamp, Date, Time)
314+
VECTOR_DOUBLE = TypeObject(list)
286315
ROWID = TypeObject()
287316
NULL = TypeObject(None)
288317

@@ -309,6 +338,7 @@ def __cmp__(self, other):
309338
"timestamp without time zone": DATETIME,
310339
"timestamp with time zone": DATETIME,
311340
"time without time zone": DATETIME,
341+
"vector double": VECTOR_DOUBLE,
312342
# Old types used by NuoDB <2.0.3
313343
"binarystring": BINARY,
314344
"binaryvaryingstring": BINARY,

pynuodb/encodedsession.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,39 @@ def putScaledCount2(self, value):
823823
self.__output += data
824824
return self
825825

826+
def putVectorDouble(self, value):
    # type: (datatype.Vector) -> EncodedSession
    """Encode a Vector of subtype Vector.DOUBLE onto the message.

    Layout written to the output, in order: the VECTOR type marker, the
    VECTOR_DOUBLE subtype marker, the payload size in count notation and
    finally every element as an 8 byte little endian double.

    :type value: datatype.Vector
    """
    # type marker followed by the subtype marker
    self.__output.append(protocol.VECTOR)
    self.__output.append(protocol.VECTOR_DOUBLE)

    # Count notation: one byte holding the number of bytes the encoded
    # size occupies, then the encoded size itself.  Every element takes
    # 8 bytes on the wire.
    encodedSize = crypt.toByteString(8 * len(value))
    self.__output.append(len(encodedSize))
    self.__output += encodedSize

    # payload: each element converted to float, little endian encoded
    packDouble = struct.Struct('<d').pack
    for element in value:
        self.__output += packDouble(float(element))

    return self
847+
848+
def putVector(self, value):
    # type: (datatype.Vector) -> EncodedSession
    """Encode a Vector onto the message, dispatching on its subtype.

    :type value: datatype.Vector
    :raises DataError: if the subtype is anything other than Vector.DOUBLE
    """
    # reject everything we do not know how to encode up front
    if value.getSubtype() != datatype.Vector.DOUBLE:
        raise DataError("unsupported value for VECTOR subtype: %d" % (value.getSubtype()))

    return self.putVectorDouble(value)
858+
826859
def putValue(self, value): # pylint: disable=too-many-return-statements
827860
# type: (Any) -> EncodedSession
828861
"""Call the supporting function based on the type of the value."""
@@ -854,6 +887,11 @@ def putValue(self, value): # pylint: disable=too-many-return-statements
854887
if isinstance(value, bool):
855888
return self.putBoolean(value)
856889

890+
# we don't want to autodetect lists as being VECTOR, so we
891+
# only bind double if it is the explicit type
892+
if isinstance(value, datatype.Vector):
893+
return self.putVector(value)
894+
857895
# I find it pretty bogus that we pass str(value) here: why not value?
858896
return self.putString(str(value))
859897

@@ -1096,6 +1134,36 @@ def getUUID(self):
10961134

10971135
raise DataError('Not a UUID')
10981136

1137+
def getVector(self):
    # type: () -> datatype.Vector
    """Read the next vector off the session.

    After the type code, one byte holds the subtype.  For VECTOR DOUBLE
    the payload size follows in count notation (one byte giving the number
    of length bytes, then the length itself) and then the elements, each
    as an 8 byte little endian double.

    :rtype: datatype.Vector
    :raises DataError: if the next value is not a VECTOR, if its subtype
        is not VECTOR DOUBLE, or if the payload size is not a multiple of 8
    """
    if self._getTypeCode() != protocol.VECTOR:
        raise DataError('Not a VECTOR')

    subtype = crypt.fromByteString(self._takeBytes(1))
    if subtype != protocol.VECTOR_DOUBLE:
        raise DataError("Unknown VECTOR type: %d" % (subtype))

    # VECTOR(<dim>, DOUBLE)
    lengthBytes = crypt.fromByteString(self._takeBytes(1))
    length = crypt.fromByteString(self._takeBytes(lengthBytes))

    if length % 8 != 0:
        raise DataError("Invalid size for VECTOR DOUBLE data: %d" % (length))

    dimension = length // 8

    # VECTOR DOUBLE stores the data as little endian
    return datatype.Vector(datatype.Vector.DOUBLE,
                           [struct.unpack('<d', self._takeBytes(8))[0]
                            for _ in range(dimension)])
1166+
10991167
def getScaledCount2(self):
11001168
# type: () -> decimal.Decimal
11011169
"""Read a scaled and signed decimal from the session.
@@ -1171,6 +1239,9 @@ def getValue(self):
11711239
if code == protocol.UUID:
11721240
return self.getUUID()
11731241

1242+
if code == protocol.VECTOR:
1243+
return self.getVector()
1244+
11741245
if code == protocol.SCALEDCOUNT2:
11751246
return self.getScaledCount2()
11761247

pynuodb/protocol.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@
4545
BLOBLEN4 = 193
4646
CLOBLEN0 = 194
4747
CLOBLEN4 = 198
48-
SCALEDCOUNT1 = 199
48+
VECTOR = 199
4949
UUID = 200
5050
SCALEDDATELEN0 = 200
5151
SCALEDDATELEN1 = 201
@@ -66,6 +66,9 @@
6666
DEBUGBARRIER = 240
6767
SCALEDTIMESTAMPNOTZ = 241
6868

69+
# subtypes of the VECTOR type
70+
VECTOR_DOUBLE = 0
71+
6972
# Protocol Messages
7073
FAILURE = 0
7174
OPENDATABASE = 3

tests/conftest.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -299,10 +299,20 @@ def database(ap, db, te):
299299
'user': db[1],
300300
'password': db[2],
301301
'options': {'schema': 'test'}} # type: DATABASE_FIXTURE
302+
system_information = {'effective_version': 0}
303+
302304
try:
303305
while True:
304306
try:
305307
conn = pynuodb.connect(**connect_args)
308+
cursor = conn.cursor()
309+
try:
310+
cursor.execute("select GETEFFECTIVEPLATFORMVERSION() from system.dual")
311+
row = cursor.fetchone()
312+
system_information['effective_version'] = row[0]
313+
finally:
314+
cursor.close()
315+
306316
break
307317
except pynuodb.session.SessionException:
308318
pass
@@ -315,4 +325,4 @@ def database(ap, db, te):
315325

316326
_log.info("Database %s is available", db[0])
317327

318-
return connect_args
328+
return {'connect_args': connect_args, 'system_information': system_information}

tests/nuodb_base.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,19 @@ class NuoBase(object):
2525
driver = pynuodb # type: Any
2626

2727
connect_args = ()
28+
system_information = ()
2829
host = None
2930

3031
lower_func = 'lower' # For stored procedure test
3132

3233
@pytest.fixture(autouse=True)
3334
def _setup(self, database):
3435
# Preserve the options we'll need to create a connection to the DB
35-
self.connect_args = database
36+
self.connect_args = database['connect_args']
37+
self.system_information = database['system_information']
3638

3739
# Verify the database is up and has a running TE
38-
dbname = database['database']
40+
dbname = self.connect_args['database']
3941
(ret, out) = nuocmd(['--show-json', 'get', 'processes',
4042
'--db-name', dbname], logout=False)
4143
assert ret == 0, "DB not running: %s" % (out)

tests/nuodb_types_test.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77

88
import decimal
99
import datetime
10+
11+
import pynuodb
12+
1013
from . import nuodb_base
1114
from .mock_tzs import localize
1215

@@ -126,3 +129,82 @@ def test_null_type(self):
126129
assert len(row) == 1
127130
assert cursor.description[0][1] == null_type
128131
assert row[0] is None
132+
133+
def test_vector_type(self):
    """Round-trip VECTOR(<dim>, DOUBLE) values through the driver.

    Covers the column metadata, decoding of result values into
    pynuodb.Vector, binding pynuodb.Vector parameters and binding a plain
    list (which goes through the implicit string conversion).
    """
    # only run this test if tested against version 8 or above; check
    # before connecting so no connection is opened (and left open) when
    # the test does not apply
    if self.system_information['effective_version'] < 1835008:
        return

    con = self._connect()
    cursor = con.cursor()

    cursor.execute("CREATE TEMPORARY TABLE tmp ("
                   " vec3 VECTOR(3, DOUBLE),"
                   " vec5 VECTOR(5, DOUBLE))")

    cursor.execute("INSERT INTO tmp VALUES ("
                   " '[1.1,2.2,33.33]',"
                   " '[-1,2,-3,4,-5]')")

    cursor.execute("SELECT * FROM tmp")

    # check metadata
    expected_columns = [("VEC3", 3), ("VEC5", 5)]
    for column, (expected_name, expected_dim) in enumerate(expected_columns):
        [name, type_code, _, _, precision, scale, _] = cursor.description[column]
        assert name == expected_name
        assert type_code == pynuodb.VECTOR_DOUBLE
        assert precision == expected_dim
        assert scale == 0

    # check content
    row = cursor.fetchone()
    assert len(row) == 2
    assert row[0] == [1.1, 2.2, 33.33]
    assert row[1] == [-1, 2, -3, 4, -5]
    assert cursor.fetchone() is None

    # check this is actually a Vector type, not just a list
    assert isinstance(row[0], pynuodb.Vector)
    assert row[0].getSubtype() == pynuodb.Vector.DOUBLE
    assert isinstance(row[1], pynuodb.Vector)
    assert row[1].getSubtype() == pynuodb.Vector.DOUBLE

    # check prepared parameters
    parameters = [pynuodb.Vector(pynuodb.Vector.DOUBLE, [11.11, -2.2, 3333.333]),
                  pynuodb.Vector(pynuodb.Vector.DOUBLE, [-1.23, 2.345, -0.34, 4, -5678.9])]
    cursor.execute("TRUNCATE TABLE tmp")
    cursor.execute("INSERT INTO tmp VALUES (?, ?)", parameters)

    cursor.execute("SELECT * FROM tmp")

    # check content
    row = cursor.fetchone()
    assert len(row) == 2
    assert row[0] == parameters[0]
    assert row[1] == parameters[1]
    assert cursor.fetchone() is None

    # check that the inserted values are interpreted correctly by the database
    cursor.execute("SELECT CAST(vec3 AS STRING) || ' - ' || CAST(vec5 AS STRING) AS strRep"
                   " FROM tmp")

    row = cursor.fetchone()
    assert len(row) == 1
    assert row[0] == "[11.11,-2.2,3333.333] - [-1.23,2.345,-0.34,4,-5678.9]"
    assert cursor.fetchone() is None

    # currently binding a list also works - this is done via implicit string
    # conversion of the passed argument in default bind case
    parameters = [[11.11, -2.2, 3333.333]]
    cursor.execute("SELECT VEC3 = ? FROM tmp", parameters)

    # check content
    row = cursor.fetchone()
    assert len(row) == 1
    assert row[0] is True
    assert cursor.fetchone() is None
assert cursor.fetchone() is None

0 commit comments

Comments
 (0)