Skip to content

Commit 007b495

Browse files
Merge pull request #735 from guzman-raphael/plugin2
Restructure Plugin Infrastructure
2 parents c871023 + ffa2d92 commit 007b495

38 files changed

+219
-155
lines changed

CHANGELOG.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
11
## Release notes
22

3+
### 0.12.5plug -- Feb 24, 2020
4+
* Support DataJoint datatype and connection plugins (#715, #729) PR 730, #735
5+
6+
### 0.12.5 -- Feb 24, 2020
7+
* Rename module `dj.schema` into `dj.schemas`. `dj.schema` remains an alias for class `dj.Schema`. (#731) PR #732
8+
* `dj.create_virtual_module` is now called `dj.VirtualModule` (#731) PR #732
9+
* Bugfix - SSL `KeyError` on failed connection (#716) PR #725
10+
* Bugfix - Unable to run unit tests using nosetests (#723) PR #724
11+
* Bugfix - `suppress_errors` does not suppress loss of connection error (#720) PR #721
12+
313
### 0.12.4 -- Jan 14, 2020
414
* Support for simple scalar datatypes in blobs (#690) PR #709
515
* Add support for the `serial` data type in declarations: alias for `bigint unsigned auto_increment` PR #713

LNX-docker-compose.yml

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ services:
1414
fakeservices.datajoint.io:
1515
condition: service_healthy
1616
environment:
17-
- DJ_HOST=db
17+
- DJ_HOST=fakeservices.datajoint.io
1818
- DJ_USER=root
1919
- DJ_PASS=simple
20-
- DJ_TEST_HOST=db
20+
- DJ_TEST_HOST=fakeservices.datajoint.io
2121
- DJ_TEST_USER=datajoint
2222
- DJ_TEST_PASSWORD=datajoint
2323
- S3_ENDPOINT=fakeservices.datajoint.io:9000
@@ -56,10 +56,10 @@ services:
5656
# - ./mysql/data:/var/lib/mysql
5757
minio:
5858
<<: *net
59+
image: minio/minio:$MINIO_VER
5960
environment:
6061
- MINIO_ACCESS_KEY=datajoint
6162
- MINIO_SECRET_KEY=datajoint
62-
image: minio/minio:$MINIO_VER
6363
# ports:
6464
# - "9000:9000"
6565
# volumes:
@@ -78,6 +78,7 @@ services:
7878
- URL=datajoint.io
7979
- SUBDOMAINS=fakeservices
8080
- MINIO_SERVER=http://minio:9000
81+
- MYSQL_SERVER=db:3306
8182
entrypoint: /entrypoint.sh
8283
healthcheck:
8384
test: wget --quiet --tries=1 --spider https://fakeservices.datajoint.io:443/minio/health/live || exit 1
@@ -87,8 +88,10 @@ services:
8788
# ports:
8889
# - "9000:9000"
8990
# - "443:443"
91+
# - "3306:3306"
9092
volumes:
9193
- ./tests/nginx/base.conf:/base.conf
94+
- ./tests/nginx/nginx.conf:/nginx.conf
9295
- ./tests/nginx/entrypoint.sh:/entrypoint.sh
9396
- ./tests/nginx/fullchain.pem:/certs/fullchain.pem
9497
- ./tests/nginx/privkey.pem:/certs/privkey.pem

datajoint/__init__.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
__date__ = "February 7, 2019"
1919
__all__ = ['__author__', '__version__',
2020
'config', 'conn', 'Connection',
21-
'schema', 'create_virtual_module', 'list_schemas',
22-
'Table', 'FreeTable',
21+
'Schema', 'schema', 'VirtualModule', 'create_virtual_module',
22+
'list_schemas', 'Table', 'FreeTable',
2323
'Manual', 'Lookup', 'Imported', 'Computed', 'Part',
2424
'Not', 'AndList', 'U', 'Diagram', 'Di', 'ERD',
2525
'set_password', 'kill',
@@ -29,8 +29,8 @@
2929
from .version import __version__
3030
from .settings import config
3131
from .connection import conn, Connection
32-
from .schema import Schema as schema
33-
from .schema import create_virtual_module, list_schemas
32+
from .schemas import Schema
33+
from .schemas import VirtualModule, list_schemas
3434
from .table import Table, FreeTable
3535
from .user_tables import Manual, Lookup, Imported, Computed, Part
3636
from .expression import Not, AndList, U
@@ -43,4 +43,6 @@
4343
from .errors import DataJointError
4444
from .migrate import migrate_dj011_external_blob_storage_to_dj012
4545

46-
ERD = Di = Diagram # Aliases for Diagram
46+
ERD = Di = Diagram # Aliases for Diagram
47+
schema = Schema # Aliases for Schema
48+
create_virtual_module = VirtualModule # Aliases for VirtualModule

datajoint/attribute_adapter.py

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import re
2-
import os
32
from .errors import DataJointError, _support_adapted_types
4-
from .plugin import override
3+
from .plugin import type_plugins
54

65

76
class AttributeAdapter:
@@ -32,9 +31,6 @@ def put(self, obj):
3231
raise NotImplementedError('Undefined attribute adapter')
3332

3433

35-
override('attribute_adapter', globals())
36-
37-
3834
def get_adapter(context, adapter_name):
3935
"""
4036
Extract the AttributeAdapter object by its name from the context and validate.
@@ -43,9 +39,8 @@ def get_adapter(context, adapter_name):
4339
raise DataJointError('Support for Adapted Attribute types is disabled.')
4440
adapter_name = adapter_name.lstrip('<').rstrip('>')
4541
try:
46-
source_module_name, adapter_name = os.path.splitext(adapter_name)
47-
adapter = context[source_module_name] if adapter_name == '' else getattr(
48-
__import__(source_module_name, fromlist=[adapter_name[1:]]), adapter_name[1:])
42+
adapter = (context[adapter_name] if adapter_name in context
43+
else type_plugins[adapter_name]['object'].load())
4944
except KeyError:
5045
raise DataJointError(
5146
"Attribute adapter '{adapter_name}' is not defined.".format(adapter_name=adapter_name))

datajoint/connection.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,21 +11,34 @@
1111
from .settings import config
1212
from . import errors
1313
from .dependencies import Dependencies
14-
from .plugin import override
14+
from .plugin import connection_plugins
1515

1616
# client errors to catch
1717
client_errors = (client.err.InterfaceError, client.err.DatabaseError)
1818

1919

2020
def get_host_hook(host_input):
21-
return host_input
21+
if '://' in host_input:
22+
plugin_name = host_input.split('://')[0]
23+
try:
24+
return connection_plugins[plugin_name]['object'].load().get_host(host_input)
25+
except KeyError:
26+
raise DataJointError(
27+
"Connection plugin '{}' not found.".format(plugin_name))
28+
else:
29+
return host_input
2230

2331

2432
def connect_host_hook(connection_obj):
25-
connection_obj.connect()
26-
27-
28-
override('connection', globals(), ['get_host_hook', 'connect_host_hook'])
33+
if '://' in connection_obj.conn_info['host_input']:
34+
plugin_name = connection_obj.conn_info['host_input'].split('://')[0]
35+
try:
36+
connection_plugins[plugin_name]['object'].load().connect_host(connection_obj)
37+
except KeyError:
38+
raise DataJointError(
39+
"Connection plugin '{}' not found.".format(plugin_name))
40+
else:
41+
connection_obj.connect()
2942

3043

3144
def translate_query_error(client_error, query):

datajoint/errors.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ class DataJointError(Exception):
1616
Base class for errors specific to DataJoint internal operation.
1717
"""
1818
def __init__(self, *args):
19-
from .plugin import discovered_plugins
19+
from .plugin import connection_plugins, type_plugins
2020
self.__cause__ = PluginWarning(
21-
'Unverified DataJoint plugin detected.') if discovered_plugins and any(
22-
[not discovered_plugins[k]['verified'] for k in discovered_plugins]) else None
21+
'Unverified DataJoint plugin detected.') if any([any(
22+
[not plugins[k]['verified'] for k in plugins])
23+
for plugins in [connection_plugins, type_plugins]
24+
if plugins]) else None
2325

2426
def suggest(self, *args):
2527
"""

datajoint/migrate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def migrate_dj011_external_blob_storage_to_dj012(migration_schema, store):
2525
Proceed?
2626
""", default='no') == 'yes'
2727
if do_migration:
28-
_migrate_dj011_blob(dj.schema(migration_schema), store)
28+
_migrate_dj011_blob(dj.Schema(migration_schema), store)
2929
print('Migration completed for schema: {}, store: {}.'.format(
3030
migration_schema, store))
3131
return

datajoint/plugin.py

Lines changed: 19 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,8 @@
1+
from .settings import config
12
import pkg_resources
23
from pathlib import Path
34
from cryptography.exceptions import InvalidSignature
45
from setuptools_certificate import hash_pkg, verify
5-
from .settings import config
6-
7-
discovered_plugins = {
8-
entry_point.module_name: dict(plugon=entry_point.name, verified=False)
9-
for entry_point
10-
in pkg_resources.iter_entry_points('datajoint.plugins')
11-
if 'plugin' not in config or entry_point.name not in config['plugin'] or
12-
entry_point.module_name in config['plugin'][entry_point.name]
13-
}
146

157

168
def _update_error_stack(plugin_name):
@@ -23,34 +15,24 @@ def _update_error_stack(plugin_name):
2315
signature = plugin_meta.get_metadata('{}.sig'.format(plugin_name))
2416
pubkey_path = str(Path(base_meta.egg_info, '{}.pub'.format(base_name)))
2517
verify(pubkey_path, data, signature)
26-
discovered_plugins[plugin_name]['verified'] = True
27-
print('DataJoint verified plugin `{}` introduced.'.format(plugin_name))
18+
print('DataJoint verified plugin `{}` detected.'.format(plugin_name))
19+
return True
2820
except (FileNotFoundError, InvalidSignature):
29-
print('Unverified plugin `{}` introduced.'.format(plugin_name))
21+
print('Unverified plugin `{}` detected.'.format(plugin_name))
22+
return False
23+
24+
25+
def _import_plugins(category):
26+
return {
27+
entry_point.name: dict(object=entry_point,
28+
verified=_update_error_stack(
29+
entry_point.module_name.split('.')[0]))
30+
for entry_point
31+
in pkg_resources.iter_entry_points('datajoint_plugins.{}'.format(category))
32+
if 'plugin' not in config or category not in config['plugin'] or
33+
entry_point.module_name.split('.')[0] in config['plugin'][category]
34+
}
3035

3136

32-
def override(plugin_type, context, method_list=None):
33-
relevant_plugins = {
34-
k: v for k, v in discovered_plugins.items() if v['plugon'] == plugin_type}
35-
if relevant_plugins:
36-
for module_name in relevant_plugins:
37-
# import plugin
38-
module = __import__(module_name)
39-
module_dict = module.__dict__
40-
# update error stack (if applicable)
41-
_update_error_stack(module.__name__)
42-
# override based on plugon preference
43-
if method_list is not None:
44-
new_methods = []
45-
for v in method_list:
46-
try:
47-
new_methods.append(getattr(module, v))
48-
except AttributeError:
49-
pass
50-
context.update(dict(zip(method_list, new_methods)))
51-
else:
52-
try:
53-
new_methods = module.__all__
54-
except AttributeError:
55-
new_methods = [name for name in module_dict if not name.startswith('_')]
56-
context.update({name: module_dict[name] for name in new_methods})
37+
connection_plugins = _import_plugins('connection')
38+
type_plugins = _import_plugins('datatype')
Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
from .user_tables import Part, Computed, Imported, Manual, Lookup
1717
from .table import lookup_class_name, Log, FreeTable
1818
import types
19-
from .plugin import override
2019

2120
logger = logging.getLogger(__name__)
2221

@@ -281,8 +280,8 @@ def repl(s):
281280
body = '\n\n\n'.join(make_class_definition(table) for table in diagram.topological_sort())
282281
python_code = '\n\n\n'.join((
283282
'"""This module was auto-generated by datajoint from an existing schema"""',
284-
"import datajoint as dj\n\nschema = dj.schema('{db}')".format(db=db),
285-
'\n'.join("{module} = dj.create_virtual_module('{module}', '{schema_name}')".format(module=v, schema_name=k)
283+
"import datajoint as dj\n\nschema = dj.Schema('{db}')".format(db=db),
284+
'\n'.join("{module} = dj.VirtualModule('{module}', '{schema_name}')".format(module=v, schema_name=k)
286285
for k, v in module_lookup.items()), body))
287286
if python_filename is None:
288287
return python_code
@@ -291,29 +290,30 @@ def repl(s):
291290
f.write(python_code)
292291

293292

294-
def create_virtual_module(module_name, schema_name, *,
295-
create_schema=False, create_tables=False, connection=None, add_objects=None):
293+
class VirtualModule(types.ModuleType):
296294
"""
297-
Creates a python module with the given name from the name of a schema on the server and
298-
automatically adds classes to it corresponding to the tables in the schema.
299-
:param module_name: displayed module name
300-
:param schema_name: name of the database in mysql
301-
:param create_schema: if True, create the schema on the database server
302-
:param create_tables: if True, module.schema can be used as the decorator for declaring new
303-
:param connection: a dj.Connection object to pass into the schema
304-
:param add_objects: additional objects to add to the module
305-
:return: the python module containing classes from the schema object and the table classes
295+
A virtual module which will contain context for schema.
306296
"""
307-
module = types.ModuleType(module_name)
308-
_schema = Schema(schema_name, create_schema=create_schema, create_tables=create_tables, connection=connection)
309-
if add_objects:
310-
module.__dict__.update(add_objects)
311-
module.__dict__['schema'] = _schema
312-
_schema.spawn_missing_classes(context=module.__dict__)
313-
return module
314-
315-
316-
override('schema', globals())
297+
def __init__(self, module_name, schema_name, *, create_schema=False,
298+
create_tables=False, connection=None, add_objects=None):
299+
"""
300+
Creates a python module with the given name from the name of a schema on the server and
301+
automatically adds classes to it corresponding to the tables in the schema.
302+
:param module_name: displayed module name
303+
:param schema_name: name of the database in mysql
304+
:param create_schema: if True, create the schema on the database server
305+
:param create_tables: if True, module.schema can be used as the decorator for declaring new
306+
:param connection: a dj.Connection object to pass into the schema
307+
:param add_objects: additional objects to add to the module
308+
:return: the python module containing classes from the schema object and the table classes
309+
"""
310+
super(VirtualModule, self).__init__(name=module_name)
311+
_schema = Schema(schema_name, create_schema=create_schema, create_tables=create_tables,
312+
connection=connection)
313+
if add_objects:
314+
self.__dict__.update(add_objects)
315+
self.__dict__['schema'] = _schema
316+
_schema.spawn_missing_classes(context=self.__dict__)
317317

318318

319319
def list_schemas(connection=None):

datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
__version__ = "0.12.5"
1+
__version__ = "0.12.5plug"
22

33
assert len(__version__) <= 10 # The log table limits version to the 10 characters

0 commit comments

Comments
 (0)