Skip to content

Commit 531a970

Browse files
author
Chris Turner
committed
Merge remote-tracking branch 'upstream/release013' into issue-829
2 parents 204e6e3 + e1f24c5 commit 531a970

26 files changed

+488
-189
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
## Release notes
22

33
### Current
4+
=======
5+
### 0.13.0 -- TBD
6+
* Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735
7+
* Allow updating specified secondary attributes using `update1` PR #763
48
* Remove python 3.5 support
59

610
### 0.12.8 -- Jan 12, 2021

datajoint.pub

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-----BEGIN PUBLIC KEY-----
2+
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDUMOo2U7YQ1uOrKU/IreM3AQP2
3+
AXJC3au+S9W+dilxHcJ3e98bRVqrFeOofcGeRPoNc38fiLmLDUiBskJeVrpm29Wo
4+
AkH6yhZWk1o8NvGMhK4DLsJYlsH6tZuOx9NITKzJuOOH6X1I5Ucs7NOSKnmu7g5g
5+
WTT5kCgF5QAe5JN8WQIDAQAB
6+
-----END PUBLIC KEY-----

datajoint/attribute_adapter.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import re
22
from .errors import DataJointError, _support_adapted_types
3+
from .plugin import type_plugins
34

45

56
class AttributeAdapter:
@@ -38,10 +39,11 @@ def get_adapter(context, adapter_name):
3839
raise DataJointError('Support for Adapted Attribute types is disabled.')
3940
adapter_name = adapter_name.lstrip('<').rstrip('>')
4041
try:
41-
adapter = context[adapter_name]
42+
adapter = (context[adapter_name] if adapter_name in context
43+
else type_plugins[adapter_name]['object'].load())
4244
except KeyError:
4345
raise DataJointError(
44-
"Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name)) from None
46+
"Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name))
4547
if not isinstance(adapter, AttributeAdapter):
4648
raise DataJointError(
4749
"Attribute adapter '{adapter_name}' must be an instance of datajoint.AttributeAdapter".format(

datajoint/autopopulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def _rename_attributes(table, props):
4141
parents = self.target.parents(primary=True, as_objects=True, foreign_key_info=True)
4242
if not parents:
4343
raise DataJointError(
44-
'A relation must have primary dependencies for auto-populate to work') from None
44+
'A relation must have primary dependencies for auto-populate to work')
4545
self._key_source = _rename_attributes(*parents[0])
4646
for q in parents[1:]:
4747
self._key_source *= _rename_attributes(*q)

datajoint/connection.py

Lines changed: 38 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,36 @@
1111
from .settings import config
1212
from . import errors
1313
from .dependencies import Dependencies
14+
from .plugin import connection_plugins
1415

1516
logger = logging.getLogger(__name__)
1617
query_log_max_length = 300
1718

1819

20+
def get_host_hook(host_input):
21+
if '://' in host_input:
22+
plugin_name = host_input.split('://')[0]
23+
try:
24+
return connection_plugins[plugin_name]['object'].load().get_host(host_input)
25+
except KeyError:
26+
raise errors.DataJointError(
27+
"Connection plugin '{}' not found.".format(plugin_name))
28+
else:
29+
return host_input
30+
31+
32+
def connect_host_hook(connection_obj):
33+
if '://' in connection_obj.conn_info['host_input']:
34+
plugin_name = connection_obj.conn_info['host_input'].split('://')[0]
35+
try:
36+
connection_plugins[plugin_name]['object'].load().connect_host(connection_obj)
37+
except KeyError:
38+
raise errors.DataJointError(
39+
"Connection plugin '{}' not found.".format(plugin_name))
40+
else:
41+
connection_obj.connect()
42+
43+
1944
def translate_query_error(client_error, query):
2045
"""
2146
Take client error and original query and return the corresponding DataJoint exception.
@@ -76,7 +101,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
76101
#encrypted-connection-options).
77102
"""
78103
if not hasattr(conn, 'connection') or reset:
79-
host = host if host is not None else config['database.host']
104+
host_input = host if host is not None else config['database.host']
105+
host = get_host_hook(host_input)
80106
user = user if user is not None else config['database.user']
81107
password = password if password is not None else config['database.password']
82108
if user is None: # pragma: no cover
@@ -85,7 +111,8 @@ def conn(host=None, user=None, password=None, *, init_fun=None, reset=False, use
85111
password = getpass(prompt="Please enter DataJoint password: ")
86112
init_fun = init_fun if init_fun is not None else config['connection.init_function']
87113
use_tls = use_tls if use_tls is not None else config['database.use_tls']
88-
conn.connection = Connection(host, user, password, None, init_fun, use_tls)
114+
conn.connection = Connection(host, user, password, None, init_fun, use_tls,
115+
host_input=host_input)
89116
return conn.connection
90117

91118

@@ -104,7 +131,8 @@ class Connection:
104131
:param use_tls: TLS encryption option
105132
"""
106133

107-
def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None):
134+
def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None,
135+
host_input=None):
108136
if ':' in host:
109137
# the port in the hostname overrides the port argument
110138
host, port = host.split(':')
@@ -115,10 +143,11 @@ def __init__(self, host, user, password, port=None, init_fun=None, use_tls=None)
115143
if use_tls is not False:
116144
self.conn_info['ssl'] = use_tls if isinstance(use_tls, dict) else {'ssl': {}}
117145
self.conn_info['ssl_input'] = use_tls
146+
self.conn_info['host_input'] = host_input
118147
self.init_fun = init_fun
119148
print("Connecting {user}@{host}:{port}".format(**self.conn_info))
120149
self._conn = None
121-
self.connect()
150+
connect_host_hook(self)
122151
if self.is_connected:
123152
logger.info("Connected {user}@{host}:{port}".format(**self.conn_info))
124153
self.connection_id = self.query('SELECT connection_id()').fetchone()[0]
@@ -149,15 +178,15 @@ def connect(self):
149178
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
150179
charset=config['connection.charset'],
151180
**{k: v for k, v in self.conn_info.items()
152-
if k != 'ssl_input'})
181+
if k not in ['ssl_input', 'host_input']})
153182
except client.err.InternalError:
154183
self._conn = client.connect(
155184
init_command=self.init_fun,
156185
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
157186
"STRICT_ALL_TABLES,NO_ENGINE_SUBSTITUTION",
158187
charset=config['connection.charset'],
159188
**{k: v for k, v in self.conn_info.items()
160-
if not(k == 'ssl_input' or
189+
if not(k in ['ssl_input', 'host_input'] or
161190
k == 'ssl' and self.conn_info['ssl_input'] is None)})
162191
self._conn.autocommit(True)
163192

@@ -194,7 +223,7 @@ def _execute_query(cursor, query, args, cursor_class, suppress_warnings):
194223
warnings.simplefilter("ignore")
195224
cursor.execute(query, args)
196225
except client.err.Error as err:
197-
raise translate_query_error(err, query) from None
226+
raise translate_query_error(err, query)
198227

199228
def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconnect=None):
200229
"""
@@ -217,10 +246,10 @@ def query(self, query, args=(), *, as_dict=False, suppress_warnings=True, reconn
217246
if not reconnect:
218247
raise
219248
warnings.warn("MySQL server has gone away. Reconnecting to the server.")
220-
self.connect()
249+
connect_host_hook(self)
221250
if self._in_transaction:
222251
self.cancel_transaction()
223-
raise errors.LostConnectionError("Connection was lost during a transaction.") from None
252+
raise errors.LostConnectionError("Connection was lost during a transaction.")
224253
logger.debug("Re-executing")
225254
cursor = self._conn.cursor(cursor=cursor_class)
226255
self._execute_query(cursor, query, args, cursor_class, suppress_warnings)

datajoint/declare.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def match_type(attribute_type):
4545
try:
4646
return next(category for category, pattern in TYPE_PATTERN.items() if pattern.match(attribute_type))
4747
except StopIteration:
48-
raise DataJointError("Unsupported attribute type {type}".format(type=attribute_type)) from None
48+
raise DataJointError("Unsupported attribute type {type}".format(type=attribute_type))
4949

5050

5151
logger = logging.getLogger(__name__)
@@ -135,7 +135,7 @@ def compile_foreign_key(line, context, attributes, primary_key, attr_sql, foreig
135135
try:
136136
result = foreign_key_parser_old.parseString(line)
137137
except pp.ParseBaseException as err:
138-
raise DataJointError('Parsing error in line "%s". %s.' % (line, err)) from None
138+
raise DataJointError('Parsing error in line "%s". %s.' % (line, err))
139139
else:
140140
obsolete = True
141141
try:
@@ -430,7 +430,7 @@ def compile_attribute(line, in_key, foreign_key_sql, context):
430430
match = attribute_parser.parseString(line + '#', parseAll=True)
431431
except pp.ParseException as err:
432432
raise DataJointError('Declaration error in position {pos} in line:\n {line}\n{msg}'.format(
433-
line=err.args[0], pos=err.args[1], msg=err.args[2])) from None
433+
line=err.args[0], pos=err.args[1], msg=err.args[2]))
434434
match['comment'] = match['comment'].rstrip('#')
435435
if 'default' not in match:
436436
match['default'] = ''

datajoint/errors.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,24 @@
55
import os
66

77

8+
# --- Unverified Plugin Check ---
9+
class PluginWarning(Exception):
10+
pass
11+
12+
813
# --- Top Level ---
914
class DataJointError(Exception):
1015
"""
1116
Base class for errors specific to DataJoint internal operation.
1217
"""
18+
def __init__(self, *args):
19+
from .plugin import connection_plugins, type_plugins
20+
self.__cause__ = PluginWarning(
21+
'Unverified DataJoint plugin detected.') if any([any(
22+
[not plugins[k]['verified'] for k in plugins])
23+
for plugins in [connection_plugins, type_plugins]
24+
if plugins]) else None
25+
1326
def suggest(self, *args):
1427
"""
1528
regenerate the exception with additional arguments

datajoint/expression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ def prep_value(k, v):
129129
try:
130130
v = uuid.UUID(v)
131131
except (AttributeError, ValueError):
132-
raise DataJointError('Badly formed UUID {v} in restriction by `{k}`'.format(k=k, v=v)) from None
132+
raise DataJointError('Badly formed UUID {v} in restriction by `{k}`'.format(k=k, v=v))
133133
return "X'%s'" % binascii.hexlify(v.bytes).decode()
134134
if isinstance(v, (datetime.date, datetime.datetime, datetime.time, decimal.Decimal)):
135135
return '"%s"' % v

datajoint/external.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def upload_filepath(self, local_filepath):
230230
relative_filepath = str(local_filepath.relative_to(self.spec['stage']).as_posix())
231231
except ValueError:
232232
raise DataJointError('The path {path} is not in stage {stage}'.format(
233-
path=local_filepath.parent, **self.spec)) from None
233+
path=local_filepath.parent, **self.spec))
234234
uuid = uuid_from_buffer(init_string=relative_filepath) # hash relative path, not contents
235235
contents_hash = uuid_from_file(local_filepath)
236236

datajoint/heading.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,8 @@ def init_from_database(self, conn, database, table_name, context):
250250
url = "https://docs.datajoint.io/python/admin/5-blob-config.html" \
251251
"#migration-between-datajoint-v0-11-and-v0-12"
252252
raise DataJointError('Legacy datatype `{type}`. Migrate your external stores to '
253-
'datajoint 0.12: {url}'.format(url=url, **attr)) from None
254-
raise DataJointError('Unknown attribute type `{type}`'.format(**attr)) from None
253+
'datajoint 0.12: {url}'.format(url=url, **attr))
254+
raise DataJointError('Unknown attribute type `{type}`'.format(**attr))
255255
if category == 'FILEPATH' and not _support_filepath_types():
256256
raise DataJointError("""
257257
The filepath data type is disabled until complete validation.

0 commit comments

Comments
 (0)