Skip to content

Commit b08150a

Browse files
committed
Revert to using datajoint.plugin
1 parent 4cdfe4f commit b08150a

File tree

4 files changed

+122
-1
lines changed

4 files changed

+122
-1
lines changed

datajoint.pub

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
-----BEGIN PUBLIC KEY-----
2+
MIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKBgQDUMOo2U7YQ1uOrKU/IreM3AQP2
3+
AXJC3au+S9W+dilxHcJ3e98bRVqrFeOofcGeRPoNc38fiLmLDUiBskJeVrpm29Wo
4+
AkH6yhZWk1o8NvGMhK4DLsJYlsH6tZuOx9NITKzJuOOH6X1I5Ucs7NOSKnmu7g5g
5+
WTT5kCgF5QAe5JN8WQIDAQAB
6+
-----END PUBLIC KEY-----

datajoint/errors.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,19 @@ class DataJointError(Exception):
1717
"""
1818

1919
def __init__(self, *args):
20-
super().__init__(*args)
20+
from .plugin import connection_plugins, type_plugins
21+
22+
self.__cause__ = (
23+
PluginWarning("Unverified DataJoint plugin detected.")
24+
if any(
25+
[
26+
any([not plugins[k]["verified"] for k in plugins])
27+
for plugins in [connection_plugins, type_plugins]
28+
if plugins
29+
]
30+
)
31+
else None
32+
)
2133

2234
def suggest(self, *args):
2335
"""

datajoint/plugin.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from .settings import config
2+
import pkg_resources
3+
from pathlib import Path
4+
from cryptography.exceptions import InvalidSignature
5+
from otumat import hash_pkg, verify
6+
import logging
7+
8+
logger = logging.getLogger(__name__.split(".")[0])
9+
10+
11+
def _update_error_stack(plugin_name):
12+
try:
13+
base_name = "datajoint"
14+
base_meta = pkg_resources.get_distribution(base_name)
15+
plugin_meta = pkg_resources.get_distribution(plugin_name)
16+
17+
data = hash_pkg(pkgpath=str(Path(plugin_meta.module_path, plugin_name)))
18+
signature = plugin_meta.get_metadata(f"{plugin_name}.sig")
19+
pubkey_path = str(Path(base_meta.egg_info, f"{base_name}.pub"))
20+
verify(pubkey_path=pubkey_path, data=data, signature=signature)
21+
logger.info(f"DataJoint verified plugin `{plugin_name}` detected.")
22+
return True
23+
except (FileNotFoundError, InvalidSignature):
24+
logger.warning(f"Unverified plugin `{plugin_name}` detected.")
25+
return False
26+
27+
28+
def _import_plugins(category):
29+
return {
30+
entry_point.name: dict(
31+
object=entry_point,
32+
verified=_update_error_stack(entry_point.module_name.split(".")[0]),
33+
)
34+
for entry_point in pkg_resources.iter_entry_points(
35+
"datajoint_plugins.{}".format(category)
36+
)
37+
if "plugin" not in config
38+
or category not in config["plugin"]
39+
or entry_point.module_name.split(".")[0] in config["plugin"][category]
40+
}
41+
42+
43+
connection_plugins = _import_plugins("connection")
44+
type_plugins = _import_plugins("datatype")

tests/test_plugin.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import pytest
2+
import datajoint.errors as djerr
3+
import datajoint.plugin as p
4+
import pkg_resources
5+
from os import path
6+
7+
8+
def test_check_pubkey():
9+
base_name = "datajoint"
10+
base_meta = pkg_resources.get_distribution(base_name)
11+
pubkey_meta = base_meta.get_metadata("{}.pub".format(base_name))
12+
13+
with open(
14+
path.join(path.abspath(path.dirname(__file__)), "..", "datajoint.pub"), "r"
15+
) as f:
16+
assert f.read() == pubkey_meta
17+
18+
19+
def test_normal_djerror():
20+
try:
21+
raise djerr.DataJointError
22+
except djerr.DataJointError as e:
23+
assert e.__cause__ is None
24+
25+
26+
def test_verified_djerror(category="connection"):
27+
try:
28+
curr_plugins = getattr(p, "{}_plugins".format(category))
29+
setattr(
30+
p,
31+
"{}_plugins".format(category),
32+
dict(test_plugin_id=dict(verified=True, object="example")),
33+
)
34+
raise djerr.DataJointError
35+
except djerr.DataJointError as e:
36+
setattr(p, "{}_plugins".format(category), curr_plugins)
37+
assert e.__cause__ is None
38+
39+
40+
def test_verified_djerror_type():
41+
test_verified_djerror(category="type")
42+
43+
44+
def test_unverified_djerror(category="connection"):
45+
try:
46+
curr_plugins = getattr(p, "{}_plugins".format(category))
47+
setattr(
48+
p,
49+
"{}_plugins".format(category),
50+
dict(test_plugin_id=dict(verified=False, object="example")),
51+
)
52+
raise djerr.DataJointError("hello")
53+
except djerr.DataJointError as e:
54+
setattr(p, "{}_plugins".format(category), curr_plugins)
55+
assert isinstance(e.__cause__, djerr.PluginWarning)
56+
57+
58+
def test_unverified_djerror_type():
59+
test_unverified_djerror(category="type")

0 commit comments

Comments
 (0)