Skip to content

Commit 322f17f

Browse files
Merge pull request #872 from dimitri-yatsenko/cascade-delete
Cleanup `schema.activate` logic
2 parents 06e441c + 65df8cd commit 322f17f

File tree

6 files changed

+109
-82
lines changed

6 files changed

+109
-82
lines changed

datajoint/schemas.py

Lines changed: 68 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
import warnings
2-
import pymysql
32
import logging
43
import inspect
54
import re
@@ -8,7 +7,7 @@
87
from .connection import conn
98
from .diagram import Diagram, _get_tier
109
from .settings import config
11-
from .errors import DataJointError
10+
from .errors import DataJointError, AccessError
1211
from .jobs import JobTable
1312
from .external import ExternalMapping
1413
from .heading import Heading
@@ -45,9 +44,11 @@ class Schema:
4544
def __init__(self, schema_name=None, context=None, *, connection=None, create_schema=True,
4645
create_tables=True, add_objects=None):
4746
"""
48-
Associate database schema `schema_name`. If the schema does not exist, attempt to create it on the server.
47+
Associate database schema `schema_name`. If the schema does not exist, attempt to
48+
create it on the server.
4949
50-
If the schema_name is omitted, then schema.activate(..) must be called later to associate with the database.
50+
If the schema_name is omitted, then schema.activate(..) must be called later
51+
to associate with the database.
5152
5253
:param schema_name: the database schema to associate.
5354
:param context: dictionary for looking up foreign key references, leave None to use local context.
@@ -70,25 +71,32 @@ def __init__(self, schema_name=None, context=None, *, connection=None, create_sc
7071
if schema_name:
7172
self.activate(schema_name)
7273

73-
def activate(self, schema_name, *, connection=None, create_schema=True,
74+
def is_activated(self):
75+
return self.database is not None
76+
77+
def activate(self, schema_name=None, *, connection=None, create_schema=None,
7478
create_tables=None, add_objects=None):
7579
"""
76-
Associate database schema `schema_name`. If the schema does not exist, attempt to create it on the server.
80+
Associate database schema `schema_name`. If the schema does not exist, attempt to
81+
create it on the server.
7782
:param schema_name: the database schema to associate.
83+
schema_name=None is used to assert that the schema has already been activated.
7884
:param connection: Connection object. Defaults to datajoint.conn().
79-
:param create_schema: When False, do not create the schema and raise an error if missing.
80-
:param create_tables: When False, do not create tables and raise errors when accessing missing tables.
81-
:param add_objects: a mapping with additional objects to make available to the context in which table classes
82-
are declared.
85+
:param create_schema: If False, do not create the schema and raise an error if missing.
86+
:param create_tables: If False, do not create tables and raise errors when attempting
87+
to access missing tables.
88+
:param add_objects: a mapping with additional objects to make available to the context
89+
in which table classes are declared.
8390
"""
8491
if schema_name is None:
85-
if self.is_activated:
92+
if self.exists:
8693
return
8794
raise DataJointError("Please provide a schema_name to activate the schema.")
88-
if self.is_activated:
95+
if self.database is not None and self.exists:
8996
if self.database == schema_name: # already activated
9097
return
91-
raise DataJointError("The schema is already activated for schema {db}.".format(db=self.database))
98+
raise DataJointError(
99+
"The schema is already activated for schema {db}.".format(db=self.database))
92100
if connection is not None:
93101
self.connection = connection
94102
if self.connection is None:
@@ -101,37 +109,32 @@ def activate(self, schema_name, *, connection=None, create_schema=True,
101109
if add_objects:
102110
self.add_objects = add_objects
103111
if not self.exists:
104-
if not create_schema or not self.database:
112+
if not self.create_schema or not self.database:
105113
raise DataJointError(
106-
"Database named `{name}` was not defined. "
114+
"Database `{name}` has not yet been declared. "
107115
"Set argument create_schema=True to create it.".format(name=schema_name))
116+
# create database
117+
logger.info("Creating schema `{name}`.".format(name=schema_name))
118+
try:
119+
self.connection.query("CREATE DATABASE `{name}`".format(name=schema_name))
120+
except AccessError:
121+
raise DataJointError(
122+
"Schema `{name}` does not exist and could not be created. "
123+
"Check permissions.".format(name=schema_name))
108124
else:
109-
# create database
110-
logger.info("Creating schema `{name}`.".format(name=schema_name))
111-
try:
112-
self.connection.query("CREATE DATABASE `{name}`".format(name=schema_name))
113-
except pymysql.OperationalError:
114-
raise DataJointError(
115-
"Schema `{name}` does not exist and could not be created. "
116-
"Check permissions.".format(name=schema_name))
117-
else:
118-
self.log('created')
119-
self.log('connect')
125+
self.log('created')
120126
self.connection.register(self)
121127

122-
# decorate all tables
128+
# decorate all tables already decorated
123129
for cls, context in self.declare_list:
124130
if self.add_objects:
125131
context = dict(context, **self.add_objects)
126132
self._decorate_master(cls, context)
127133

128-
@property
129-
def is_activated(self):
130-
return self.database is not None
131-
132-
def _assert_activation(self, message="The schema must be activated first."):
133-
if not self.is_activated:
134-
raise DataJointError(message)
134+
def _assert_exists(self, message=None):
135+
if not self.exists:
136+
raise DataJointError(
137+
message or "Schema `{db}` has not been created.".format(db=self.database))
135138

136139
def __call__(self, cls, *, context=None):
137140
"""
@@ -142,7 +145,7 @@ def __call__(self, cls, *, context=None):
142145
context = context or self.context or inspect.currentframe().f_back.f_locals
143146
if issubclass(cls, Part):
144147
raise DataJointError('The schema decorator should not be applied to Part relations')
145-
if self.is_activated:
148+
if self.is_activated():
146149
self._decorate_master(cls, context)
147150
else:
148151
self.declare_list.append((cls, context))
@@ -204,7 +207,7 @@ def _decorate_table(self, table_class, context, assert_declared=False):
204207

205208
@property
206209
def log(self):
207-
self._assert_activation()
210+
self._assert_exists()
208211
if self._log is None:
209212
self._log = Log(self.connection, self.database)
210213
return self._log
@@ -217,7 +220,7 @@ def size_on_disk(self):
217220
"""
218221
:return: size of the entire schema in bytes
219222
"""
220-
self._assert_activation()
223+
self._assert_exists()
221224
return int(self.connection.query(
222225
"""
223226
SELECT SUM(data_length + index_length)
@@ -230,7 +233,7 @@ def spawn_missing_classes(self, context=None):
230233
in the context.
231234
:param context: alternative context to place the missing classes into, e.g. locals()
232235
"""
233-
self._assert_activation()
236+
self._assert_exists()
234237
if context is None:
235238
if self.context is not None:
236239
context = self.context
@@ -273,43 +276,49 @@ def drop(self, force=False):
273276
"""
274277
Drop the associated schema if it exists
275278
"""
276-
self._assert_activation()
277279
if not self.exists:
278-
logger.info("Schema named `{database}` does not exist. Doing nothing.".format(database=self.database))
280+
logger.info("Schema named `{database}` does not exist. Doing nothing.".format(
281+
database=self.database))
279282
elif (not config['safemode'] or
280283
force or
281284
user_choice("Proceed to delete entire schema `%s`?" % self.database, default='no') == 'yes'):
282285
logger.info("Dropping `{database}`.".format(database=self.database))
283286
try:
284287
self.connection.query("DROP DATABASE `{database}`".format(database=self.database))
285288
logger.info("Schema `{database}` was dropped successfully.".format(database=self.database))
286-
except pymysql.OperationalError:
287-
raise DataJointError("An attempt to drop schema `{database}` "
288-
"has failed. Check permissions.".format(database=self.database))
289+
except AccessError:
290+
raise AccessError(
291+
"An attempt to drop schema `{database}` "
292+
"has failed. Check permissions.".format(database=self.database))
289293

290294
@property
291295
def exists(self):
292296
"""
293297
:return: true if the associated schema exists on the server
294298
"""
295-
self._assert_activation()
296-
cur = self.connection.query("SHOW DATABASES LIKE '{database}'".format(database=self.database))
297-
return cur.rowcount > 0
299+
if self.database is None:
300+
raise DataJointError("Schema must be activated first.")
301+
return self.database is not None and (
302+
self.connection.query(
303+
"SELECT schema_name "
304+
"FROM information_schema.schemata "
305+
"WHERE schema_name = '{database}'".format(
306+
database=self.database)).rowcount > 0)
298307

299308
@property
300309
def jobs(self):
301310
"""
302311
schema.jobs provides a view of the job reservation table for the schema
303312
:return: jobs table
304313
"""
305-
self._assert_activation()
314+
self._assert_exists()
306315
if self._jobs is None:
307316
self._jobs = JobTable(self.connection, self.database)
308317
return self._jobs
309318

310319
@property
311320
def code(self):
312-
self._assert_activation()
321+
self._assert_exists()
313322
return self.save()
314323

315324
def save(self, python_filename=None):
@@ -318,7 +327,7 @@ def save(self, python_filename=None):
318327
This method is in preparation for a future release and is not officially supported.
319328
:return: a string containing the body of a complete Python module defining this schema.
320329
"""
321-
self._assert_activation()
330+
self._assert_exists()
322331
module_count = itertools.count()
323332
# add virtual modules for referenced modules with names vmod0, vmod1, ...
324333
module_lookup = collections.defaultdict(lambda: 'vmod' + str(next(module_count)))
@@ -358,20 +367,18 @@ def replace(s):
358367
for k, v in module_lookup.items()), body))
359368
if python_filename is None:
360369
return python_code
361-
else:
362-
with open(python_filename, 'wt') as f:
363-
f.write(python_code)
370+
with open(python_filename, 'wt') as f:
371+
f.write(python_code)
364372

365373
def list_tables(self):
366374
"""
367375
Return a list of all tables in the schema except tables with ~ in first character such
368376
as ~logs and ~job
369-
:return: A list of table names in their raw datajoint naming convection form
377+
:return: A list of table names from the database schema.
370378
"""
371-
372379
return [table_name for (table_name,) in self.connection.query("""
373380
SELECT table_name FROM information_schema.tables
374-
WHERE table_schema = %s and table_name NOT LIKE '~%%'""", args=(self.database))]
381+
WHERE table_schema = %s and table_name NOT LIKE '~%%'""", args=(self.database,))]
375382

376383

377384
class VirtualModule(types.ModuleType):
@@ -393,8 +400,8 @@ def __init__(self, module_name, schema_name, *, create_schema=False,
393400
:return: the python module containing classes from the schema object and the table classes
394401
"""
395402
super(VirtualModule, self).__init__(name=module_name)
396-
_schema = Schema(schema_name, create_schema=create_schema, create_tables=create_tables,
397-
connection=connection)
403+
_schema = Schema(schema_name, create_schema=create_schema,
404+
create_tables=create_tables, connection=connection)
398405
if add_objects:
399406
self.__dict__.update(add_objects)
400407
self.__dict__['schema'] = _schema
@@ -406,4 +413,7 @@ def list_schemas(connection=None):
406413
:param connection: a dj.Connection object
407414
:return: list of all accessible schemas on the server
408415
"""
409-
return [r[0] for r in (connection or conn()).query('SHOW SCHEMAS') if r[0] not in {'information_schema'}]
416+
return [r[0] for r in (connection or conn()).query(
417+
'SELECT schema_name '
418+
'FROM information_schema.schemata '
419+
'WHERE schema_name <> "information_schema"')]

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.13.dev4"
1+
__version__ = "0.13.dev5"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

tests/schema_university.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import datajoint as dj
2-
from . import PREFIX, CONN_INFO
32

4-
schema = dj.Schema(connection=dj.conn(**CONN_INFO))
3+
schema = dj.Schema()
54

65

76
@schema
@@ -109,15 +108,4 @@ class Grade(dj.Manual):
109108
-> Enroll
110109
---
111110
-> LetterGrade
112-
"""
113-
114-
115-
schema.activate(PREFIX + '_university') # deferred activation
116-
117-
# --------------- Fill University -------------------
118-
119-
for table in Student, Department, StudentMajor, Course, Term, CurrentTerm, Section, Enroll, Grade:
120-
import csv
121-
with open('./data/' + table.__name__ + '.csv') as f:
122-
reader = csv.DictReader(f)
123-
table().insert(reader)
111+
"""

tests/test_schema.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,27 @@ def test_schema_size_on_disk():
2626
assert_true(isinstance(number_of_bytes, int))
2727

2828

29+
def test_schema_list():
30+
schemas = dj.list_schemas()
31+
assert_true(schema.schema.database in schemas)
32+
33+
34+
@raises(dj.errors.AccessError)
35+
def test_drop_unauthorized():
36+
info_schema = dj.schema('information_schema')
37+
info_schema.drop()
38+
39+
2940
def test_namespace_population():
3041
for name, rel in getmembers(schema, relation_selector):
3142
assert_true(hasattr(schema_empty, name), '{name} not found in schema_empty'.format(name=name))
32-
assert_true(rel.__base__ is getattr(schema_empty, name).__base__, 'Wrong tier for {name}'.format(name=name))
43+
assert_true(rel.__base__ is getattr(schema_empty, name).__base__,
44+
'Wrong tier for {name}'.format(name=name))
3345

3446
for name_part in dir(rel):
3547
if name_part[0].isupper() and part_selector(getattr(rel, name_part)):
36-
assert_true(getattr(rel, name_part).__base__ is dj.Part, 'Wrong tier for {name}'.format(name=name_part))
48+
assert_true(getattr(rel, name_part).__base__ is dj.Part,
49+
'Wrong tier for {name}'.format(name=name_part))
3750

3851

3952
@raises(dj.DataJointError)

tests/test_university.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from nose.tools import assert_true, assert_list_equal, assert_false
2-
from .schema_university import *
1+
from nose.tools import assert_true, assert_list_equal, assert_false, raises
32
import hashlib
3+
from datajoint import DataJointError
4+
from .schema_university import *
5+
from . import PREFIX, CONN_INFO
46

57

68
def _hash4(table):
@@ -10,10 +12,23 @@ def _hash4(table):
1012
return hashlib.md5(blob).digest().hex()[:4]
1113

1214

15+
@raises(DataJointError)
16+
def test_activate_unauthorized():
17+
schema.activate('unauthorized', connection=dj.conn(**CONN_INFO))
18+
19+
20+
def test_activate():
21+
schema.activate(PREFIX + '_university', connection=dj.conn(**CONN_INFO)) # deferred activation
22+
# --------------- Fill University -------------------
23+
for table in Student, Department, StudentMajor, Course, Term, CurrentTerm, Section, Enroll, Grade:
24+
import csv
25+
with open('./data/' + table.__name__ + '.csv') as f:
26+
reader = csv.DictReader(f)
27+
table().insert(reader)
28+
29+
1330
def test_fill():
14-
"""
15-
check that the randomized tables are consistently defined
16-
"""
31+
""" check that the randomized tables are consistently defined """
1732
# check randomized tables
1833
assert_true(len(Student()) == 300 and _hash4(Student) == '1e1a')
1934
assert_true(len(StudentMajor()) == 226 and _hash4(StudentMajor) == '3129')
@@ -27,7 +42,6 @@ def test_restrict():
2742
test diverse restrictions from the university database.
2843
This test relies on a specific instantiation of the database.
2944
"""
30-
3145
utahns1 = Student & {'home_state': 'UT'}
3246
utahns2 = Student & 'home_state="UT"'
3347
assert_true(len(utahns1) == len(utahns2.fetch('KEY')) == 7)
@@ -41,7 +55,8 @@ def test_restrict():
4155
assert_true(set(sex1).pop() == set(sex2).pop() == "M")
4256

4357
# students from OK, NM, TX
44-
s1 = (Student & [{'home_state': s} for s in ('OK', 'NM', 'TX')]).fetch("KEY", order_by="student_id")
58+
s1 = (Student & [{'home_state': s} for s in ('OK', 'NM', 'TX')]).fetch(
59+
"KEY", order_by="student_id")
4560
s2 = (Student & 'home_state in ("OK", "NM", "TX")').fetch('KEY', order_by="student_id")
4661
assert_true(len(s1) == 11)
4762
assert_list_equal(s1, s2)

0 commit comments

Comments
 (0)