Skip to content

Commit 49a93a7

Browse files
authored
Merge pull request #456 from dimitri-yatsenko/master
Fix #431, #451, #458, #463, #466
2 parents 53294a4 + 66ad963 commit 49a93a7

23 files changed

+162
-114
lines changed

CHANGELOG.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
11
## Release notes
2-
### 0.10.0 -- work in progress
2+
### 0.10.1 -- Work in progress
3+
* Networkx 2.0 support (#443)
4+
* Sped up queries (#428)
5+
6+
### 0.10.0 -- January 10, 2018
7+
* Deletes are more efficient (#424)
8+
* ERD shows table definition on tooltip hover in Jupyter (#422)
39
* S3 external storage
410
* Garbage collection for external sorage
5-
* Most operators and methods of tables can be invoked as class methods
11+
* Most operators and methods of tables can be invoked as class methods rather than instance methods (#407)
612
* The schema decorator object no longer requires locals() to specify the context
713
* Compatibility with pymysql 0.8.0+
14+
* More efficient loading of dependencies (#403)
815

916
### 0.9.0 -- November 17, 2017
1017
* Made graphviz installation optional

datajoint/__init__.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
'Connection', 'Heading', 'FreeRelation', 'Not', 'schema',
2525
'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
2626
'AndList', 'OrList', 'ERD', 'U',
27+
'DataJointError', 'DuplicateError',
2728
'set_password']
2829

2930

@@ -34,12 +35,6 @@ class key:
3435
pass
3536

3637

37-
class DataJointError(Exception):
38-
"""
39-
Base class for errors specific to DataJoint internal operation.
40-
"""
41-
pass
42-
4338
# ----------- loads local configuration from file ----------------
4439
from .settings import Config, LOCALCONFIG, GLOBALCONFIG, logger, log_levels
4540
config = Config()
@@ -73,6 +68,7 @@ class DataJointError(Exception):
7368
from .schema import Schema as schema
7469
from .erd import ERD
7570
from .admin import set_password, kill
71+
from .errors import DataJointError, DuplicateError
7672

7773

7874
def create_virtual_module(modulename, dbname):

datajoint/autopopulate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from tqdm import tqdm
77
from pymysql import OperationalError
88
from .relational_operand import RelationalOperand, AndList, U
9-
from . import DataJointError
9+
from .errors import DataJointError
1010
from .base_relation import FreeRelation
1111
import signal
1212

datajoint/base_relation.py

Lines changed: 42 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,13 @@
77
import logging
88
import warnings
99
from pymysql import OperationalError, InternalError, IntegrityError
10-
from . import config, DataJointError
10+
from . import config
1111
from .declare import declare
1212
from .relational_operand import RelationalOperand
1313
from .blob import pack
1414
from .utils import user_choice
1515
from .heading import Heading
16-
from .settings import server_error_codes
16+
from .errors import server_error_codes, DataJointError, DuplicateError
1717
from . import __version__ as version
1818

1919
logger = logging.getLogger(__name__)
@@ -42,7 +42,12 @@ def heading(self):
4242
if self._heading is None:
4343
self._heading = Heading() # instance-level heading
4444
if not self._heading: # lazy loading of heading
45-
self._heading.init_from_database(self.connection, self.database, self.table_name)
45+
if self.connection is None:
46+
raise DataJointError(
47+
'DataJoint class is missing a database connection. '
48+
'Missing schema decorator on the class? (e.g. @schema)')
49+
else:
50+
self._heading.init_from_database(self.connection, self.database, self.table_name)
4651
return self._heading
4752

4853
@property
@@ -172,7 +177,8 @@ def insert(self, rows, replace=False, skip_duplicates=False, ignore_extra_fields
172177
fields='`' + '`,`'.join(fields) + '`',
173178
table=self.full_table_name,
174179
select=rows.make_sql(select_fields=fields),
175-
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`=`{pk}`'.format(pk=self.primary_key[0])
180+
duplicate=(' ON DUPLICATE KEY UPDATE `{pk}`={table}.`{pk}`'.format(
181+
table=self.full_table_name, pk=self.primary_key[0])
176182
if skip_duplicates else ''))
177183
self.connection.query(query)
178184
return
@@ -282,10 +288,12 @@ def check_fields(fields):
282288
elif err.args[0] == server_error_codes['unknown column']:
283289
# args[1] -> Unknown column 'extra' in 'field list'
284290
raise DataJointError(
285-
'{} : To ignore extra fields, set ignore_extra_fields=True in insert.'.format(err.args[1])) from None
291+
'{} : To ignore extra fields, set ignore_extra_fields=True in insert.'.format(err.args[1])
292+
) from None
286293
elif err.args[0] == server_error_codes['duplicate entry']:
287-
raise DataJointError(
288-
'{} : To ignore duplicate entries, set skip_duplicates=True in insert.'.format(err.args[1])) from None
294+
raise DuplicateError(
295+
'{} : To ignore duplicate entries, set skip_duplicates=True in insert.'.format(err.args[1])
296+
) from None
289297
else:
290298
raise
291299

@@ -434,11 +442,15 @@ def show_definition(self):
434442
logger.warning('show_definition is deprecated. Use describe instead.')
435443
return self.describe()
436444

437-
def describe(self, printout=True):
445+
def describe(self, context=None, printout=True):
438446
"""
439447
:return: the definition string for the relation using DataJoint DDL.
440448
This does not yet work for aliased foreign keys.
441449
"""
450+
if context is None:
451+
frame = inspect.currentframe().f_back
452+
context = dict(frame.f_globals, **frame.f_locals)
453+
del frame
442454
if self.full_table_name not in self.connection.dependencies:
443455
self.connection.dependencies.load()
444456
parents = self.parents()
@@ -460,14 +472,14 @@ def describe(self, printout=True):
460472
parents.pop(parent_name)
461473
if not parent_name.isdigit():
462474
definition += '-> {class_name}\n'.format(
463-
class_name=lookup_class_name(parent_name, self.context) or parent_name)
475+
class_name=lookup_class_name(parent_name, context) or parent_name)
464476
else:
465477
# aliased foreign key
466478
parent_name = list(self.connection.dependencies.in_edges(parent_name))[0][0]
467479
lst = [(attr, ref) for attr, ref in fk_props['attr_map'].items() if ref != attr]
468480
definition += '({attr_list}) -> {class_name}{ref_list}\n'.format(
469481
attr_list=','.join(r[0] for r in lst),
470-
class_name=lookup_class_name(parent_name, self.context) or parent_name,
482+
class_name=lookup_class_name(parent_name, context) or parent_name,
471483
ref_list=('' if len(attributes_thus_far) - len(attributes_declared) == 1
472484
else '(%s)' % ','.join(r[1] for r in lst)))
473485
attributes_declared.update(fk_props['attr_map'])
@@ -540,25 +552,26 @@ def lookup_class_name(name, context, depth=3):
540552
while nodes:
541553
node = nodes.pop(0)
542554
for member_name, member in node['context'].items():
543-
if inspect.isclass(member) and issubclass(member, BaseRelation):
544-
if member.full_table_name == name: # found it!
545-
return '.'.join([node['context_name'], member_name]).lstrip('.')
546-
try: # look for part tables
547-
parts = member._ordered_class_members
548-
except AttributeError:
549-
pass # not a UserRelation -- cannot have part tables.
550-
else:
551-
for part in (getattr(member, p) for p in parts if p[0].isupper() and hasattr(member, p)):
552-
if inspect.isclass(part) and issubclass(part, BaseRelation) and part.full_table_name == name:
553-
return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.')
554-
elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint':
555-
try:
556-
nodes.append(
557-
dict(context=dict(inspect.getmembers(member)),
558-
context_name=node['context_name'] + '.' + member_name,
559-
depth=node['depth']-1))
560-
except ImportError:
561-
pass # could not import, so do not attempt
555+
if not member_name.startswith('_'): # skip IPython's implicit variables
556+
if inspect.isclass(member) and issubclass(member, BaseRelation):
557+
if member.full_table_name == name: # found it!
558+
return '.'.join([node['context_name'], member_name]).lstrip('.')
559+
try: # look for part tables
560+
parts = member._ordered_class_members
561+
except AttributeError:
562+
pass # not a UserRelation -- cannot have part tables.
563+
else:
564+
for part in (getattr(member, p) for p in parts if p[0].isupper() and hasattr(member, p)):
565+
if inspect.isclass(part) and issubclass(part, BaseRelation) and part.full_table_name == name:
566+
return '.'.join([node['context_name'], member_name, part.__name__]).lstrip('.')
567+
elif node['depth'] > 0 and inspect.ismodule(member) and member.__name__ != 'datajoint':
568+
try:
569+
nodes.append(
570+
dict(context=dict(inspect.getmembers(member)),
571+
context_name=node['context_name'] + '.' + member_name,
572+
depth=node['depth']-1))
573+
except ImportError:
574+
pass # could not import, so do not attempt
562575
return None
563576

564577

datajoint/blob.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import zlib
66
from collections import OrderedDict, Mapping, Iterable
7+
from decimal import Decimal
78
import numpy as np
8-
from . import DataJointError
9+
from .errors import DataJointError
910

1011
mxClassID = OrderedDict((
1112
# see http://www.mathworks.com/help/techdoc/apiref/mxclassid.html
@@ -223,14 +224,16 @@ def pack_obj(obj):
223224
blob = b''
224225
if isinstance(obj, np.ndarray):
225226
blob += pack_array(obj)
226-
elif isinstance(obj, Mapping): # TODO: check if this is a good inheritance check for dict etc.
227+
elif isinstance(obj, Mapping):
227228
blob += pack_dict(obj)
228229
elif isinstance(obj, str):
229230
blob += pack_array(np.array(obj, dtype=np.dtype('c')))
230231
elif isinstance(obj, Iterable):
231232
blob += pack_array(np.array(list(obj)))
232233
elif isinstance(obj, int) or isinstance(obj, float):
233234
blob += pack_array(np.array(obj))
235+
elif isinstance(obj, Decimal):
236+
blob += pack_array(np.array(np.float64(obj)))
234237
else:
235238
raise DataJointError("Packing object of type %s currently not supported!" % type(obj))
236239

datajoint/connection.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@
77
import pymysql as client
88
import logging
99
from getpass import getpass
10+
from pymysql import err
1011

1112
from . import config
12-
from . import DataJointError
13+
from .errors import DataJointError, server_error_codes
1314
from .dependencies import Dependencies
14-
from pymysql import err
15+
1516

1617
logger = logging.getLogger(__name__)
1718

@@ -94,6 +95,7 @@ def connect(self):
9495
self._conn = client.connect(init_command=self.init_fun,
9596
sql_mode="NO_ZERO_DATE,NO_ZERO_IN_DATE,ERROR_FOR_DIVISION_BY_ZERO,"
9697
"STRICT_TRANS_TABLES,NO_AUTO_CREATE_USER,NO_ENGINE_SUBSTITUTION",
98+
charset=config['connection.charset'],
9799
**self.conn_info)
98100

99101
def register(self, schema):
@@ -118,7 +120,7 @@ def query(self, query, args=(), as_dict=False, suppress_warnings=True):
118120
:param args: additional arguments for the client.cursor
119121
:param as_dict: If as_dict is set to True, the returned cursor objects returns
120122
query results as dictionary.
121-
:param suppress_warning: If True, suppress all warnings arising from underlying query library
123+
:param suppress_warnings: If True, suppress all warnings arising from underlying query library
122124
"""
123125

124126
cursor = client.cursors.DictCursor if as_dict else client.cursors.Cursor
@@ -144,11 +146,12 @@ def query(self, query, args=(), as_dict=False, suppress_warnings=True):
144146
else:
145147
raise
146148
except err.ProgrammingError as e:
147-
raise DataJointError("\n".join((
148-
"Error in query:", query,
149-
"Please check spelling, syntax, and existence of tables and attributes.",
150-
"When restricting a relation by a condition in a string, enclose attributes in backquotes."
151-
)))
149+
if e.args[0] == server_error_codes['parse error']:
150+
raise DataJointError("\n".join((
151+
"Error in query:", query,
152+
"Please check spelling, syntax, and existence of tables and attributes.",
153+
"When restricting a relation by a condition in a string, enclose attributes in backquotes."
154+
))) from None
152155
return cur
153156

154157
def get_user(self):

datajoint/declare.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import pyparsing as pp
77
import logging
88

9-
from . import DataJointError, config
9+
from . import config
10+
from .errors import DataJointError
1011

1112
STORE_NAME_LENGTH = 8
1213
STORE_HASH_LENGTH = 43
@@ -176,7 +177,7 @@ def declare(full_table_name, definition, context):
176177
['PRIMARY KEY (`' + '`,`'.join(primary_key) + '`)'] +
177178
foreign_key_sql +
178179
index_sql) +
179-
'\n) ENGINE=InnoDB, CHARACTER SET latin1, COMMENT "%s"' % table_comment), uses_external
180+
'\n) ENGINE=InnoDB, COMMENT "%s"' % table_comment), uses_external
180181

181182

182183
def compile_attribute(line, in_key, foreign_key_sql):

datajoint/dependencies.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import networkx as nx
22
import itertools
33
from collections import defaultdict
4-
from . import DataJointError
4+
from .errors import DataJointError
55

66

77
class Dependencies(nx.DiGraph):

datajoint/erd.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import functools
44
import io
55
import warnings
6+
import inspect
67
from .base_relation import BaseRelation
78

89
try:
@@ -12,7 +13,8 @@
1213
except:
1314
erd_active = False
1415

15-
from . import Manual, Imported, Computed, Lookup, Part, DataJointError
16+
from . import Manual, Imported, Computed, Lookup, Part
17+
from .errors import DataJointError
1618
from .base_relation import lookup_class_name
1719

1820

@@ -82,15 +84,13 @@ def __init__(self, source, context=None):
8284
super().__init__(source)
8385
return
8486

85-
# get the caller's locals()
87+
# get the caller's context
8688
if context is None:
87-
import inspect
88-
frame = inspect.currentframe()
89-
try:
90-
context = frame.f_back.f_locals
91-
finally:
92-
del frame
93-
self.context = context
89+
frame = inspect.currentframe().f_back
90+
self.context = dict(frame.f_globals, **frame.f_locals)
91+
del frame
92+
else:
93+
self.context = context
9494

9595
# find connection in the source
9696
try:
@@ -209,10 +209,7 @@ def _make_graph(self):
209209
graph = nx.DiGraph(nx.DiGraph(self).subgraph(nodes))
210210
nx.set_node_attributes(graph, name='node_type', values={n: _get_tier(n) for n in graph})
211211
# relabel nodes to class names
212-
clean_context = dict((k, v) for k, v in self.context.items()
213-
if not k.startswith('_')) # exclude ipython's implicit variables
214-
mapping = {node: (lookup_class_name(node, clean_context) or node)
215-
for node in graph.nodes()}
212+
mapping = {node: lookup_class_name(node, self.context) or node for node in graph.nodes()}
216213
new_names = [mapping.values()]
217214
if len(new_names) > len(set(new_names)):
218215
raise DataJointError('Some classes have identical names. The ERD cannot be plotted.')
@@ -258,7 +255,7 @@ def make_dot(self):
258255
if name.split('.')[0] in self.context:
259256
cls = eval(name, self.context)
260257
assert(issubclass(cls, BaseRelation))
261-
description = cls().describe(printout=False).split('\n')
258+
description = cls().describe(context=self.context, printout=False).split('\n')
262259
description = (
263260
'-'*30 if q.startswith('---') else q.replace('->', '→') if '->' in q else q.split(':')[0]
264261
for q in description if not q.startswith('#'))

datajoint/errors.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
server_error_codes = {
2+
'unknown column': 1054,
3+
'duplicate entry': 1062,
4+
'parse error': 1064,
5+
'command denied': 1142,
6+
'table does not exist': 1146,
7+
'syntax error': 1149
8+
}
9+
10+
11+
class DataJointError(Exception):
12+
"""
13+
Base class for errors specific to DataJoint internal operation.
14+
"""
15+
pass
16+
17+
18+
class DuplicateError(DataJointError):
19+
"""
20+
Error caused by a violation of a unique constraint when inserting data
21+
"""
22+
pass

0 commit comments

Comments
 (0)