Skip to content

Commit 0e194e7

Browse files
committed
Merge branch 'dev-tests' into dev-tests-plat-166-schema
2 parents b9ccb4f + 20c5039 commit 0e194e7

File tree

1 file changed

+344
-0
lines changed

1 file changed

+344
-0
lines changed

tests/test_relation.py

Lines changed: 344 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,344 @@
1+
import pytest
2+
from inspect import getmembers
3+
import re
4+
import pandas
5+
import numpy as np
6+
import datajoint as dj
7+
from datajoint.table import Table
8+
from unittest.mock import patch
9+
10+
from . import schema
11+
12+
13+
@pytest.fixture
14+
def test(schema_any):
15+
yield schema.TTest()
16+
17+
18+
@pytest.fixture
19+
def test2(schema_any):
20+
yield schema.TTest2()
21+
22+
23+
@pytest.fixture
24+
def test_extra(schema_any):
25+
yield schema.TTestExtra()
26+
27+
28+
@pytest.fixture
29+
def test_no_extra(schema_any):
30+
yield schema.TTestNoExtra()
31+
32+
33+
@pytest.fixture
34+
def user(schema_any):
35+
return schema.User()
36+
37+
38+
@pytest.fixture
39+
def subject(schema_any):
40+
return schema.Subject()
41+
42+
43+
@pytest.fixture
44+
def experiment(schema_any):
45+
return schema.Experiment()
46+
47+
48+
@pytest.fixture
49+
def ephys(schema_any):
50+
return schema.Ephys()
51+
52+
53+
@pytest.fixture
54+
def img(schema_any):
55+
return schema.Image()
56+
57+
58+
@pytest.fixture
59+
def trash(schema_any):
60+
return schema.UberTrash()
61+
62+
63+
def test_contents(user, subject):
64+
"""
65+
test the ability of tables to self-populate using the contents property
66+
"""
67+
# test contents
68+
assert user
69+
assert len(user) == len(user.contents)
70+
u = user.fetch(order_by=["username"])
71+
assert list(u["username"]) == sorted([s[0] for s in user.contents])
72+
73+
# test prepare
74+
assert subject
75+
assert len(subject) == len(subject.contents)
76+
u = subject.fetch(order_by=["subject_id"])
77+
assert list(u["subject_id"]) == sorted([s[0] for s in subject.contents])
78+
79+
80+
def test_misnamed_attribute1(user):
81+
with pytest.raises(dj.DataJointError):
82+
user.insert([dict(username="Bob"), dict(user="Alice")])
83+
84+
85+
def test_misnamed_attribute2(user):
86+
with pytest.raises(KeyError):
87+
user.insert1(dict(user="Bob"))
88+
89+
90+
def test_extra_attribute1(user):
91+
with pytest.raises(KeyError):
92+
user.insert1(dict(username="Robert", spouse="Alice"))
93+
94+
95+
def test_extra_attribute2(user):
96+
user.insert1(dict(username="Robert", spouse="Alice"), ignore_extra_fields=True)
97+
98+
99+
def test_missing_definition(schema_any):
100+
class MissingDefinition(dj.Manual):
101+
definitions = """ # misspelled definition
102+
id : int
103+
---
104+
comment : varchar(16) # otherwise everything's normal
105+
"""
106+
107+
with pytest.raises(NotImplementedError):
108+
schema_any(MissingDefinition, context=dict(MissingDefinition=MissingDefinition))
109+
110+
111+
def test_empty_insert1(user):
112+
with pytest.raises(dj.DataJointError):
113+
user.insert1(())
114+
115+
116+
def test_empty_insert(user):
117+
with pytest.raises(dj.DataJointError):
118+
user.insert([()])
119+
120+
121+
def test_wrong_arguments_insert(user):
122+
with pytest.raises(dj.DataJointError):
123+
user.insert1(("First", "Second"))
124+
125+
126+
def test_wrong_insert_type(user):
127+
with pytest.raises(dj.DataJointError):
128+
user.insert1(3)
129+
130+
131+
def test_insert_select(subject, test, test2):
132+
test2.delete()
133+
test2.insert(test)
134+
assert len(test2) == len(test)
135+
136+
original_length = len(subject)
137+
elements = subject.proj(..., s="subject_id")
138+
elements = elements.proj(
139+
"real_id",
140+
"date_of_birth",
141+
"subject_notes",
142+
subject_id="s+1000",
143+
species='"human"',
144+
)
145+
subject.insert(elements, ignore_extra_fields=True)
146+
assert len(subject) == 2 * original_length
147+
148+
149+
def test_insert_pandas_roundtrip(test, test2):
150+
"""ensure fetched frames can be inserted"""
151+
test2.delete()
152+
n = len(test)
153+
assert n > 0
154+
df = test.fetch(format="frame")
155+
assert isinstance(df, pandas.DataFrame)
156+
assert len(df) == n
157+
test2.insert(df)
158+
assert len(test2) == n
159+
160+
161+
def test_insert_pandas_userframe(test, test2):
162+
"""
163+
ensure simple user-created frames (1 field, non-custom index)
164+
can be inserted without extra index adjustment
165+
"""
166+
test2.delete()
167+
n = len(test)
168+
assert n > 0
169+
df = pandas.DataFrame(test.fetch())
170+
assert isinstance(df, pandas.DataFrame)
171+
assert len(df) == n
172+
test2.insert(df)
173+
assert len(test2) == n
174+
175+
176+
def test_insert_select_ignore_extra_fields0(test, test_extra):
177+
"""need ignore extra fields for insert select"""
178+
test_extra.insert1((test.fetch("key").max() + 1, 0, 0))
179+
with pytest.raises(dj.DataJointError):
180+
test.insert(test_extra)
181+
182+
183+
def test_insert_select_ignore_extra_fields1(test, test_extra):
184+
"""make sure extra fields works in insert select"""
185+
test_extra.delete()
186+
keyno = test.fetch("key").max() + 1
187+
test_extra.insert1((keyno, 0, 0))
188+
test.insert(test_extra, ignore_extra_fields=True)
189+
assert keyno in test.fetch("key")
190+
191+
192+
def test_insert_select_ignore_extra_fields2(test_no_extra, test):
193+
"""make sure insert select still works when ignoring extra fields when there are none"""
194+
test_no_extra.delete()
195+
test_no_extra.insert(test, ignore_extra_fields=True)
196+
197+
198+
def test_insert_select_ignore_extra_fields3(test, test_no_extra, test_extra):
199+
"""make sure insert select works for from query result"""
200+
# Recreate table state from previous tests
201+
keyno = test.fetch("key").max() + 1
202+
test_extra.insert1((keyno, 0, 0))
203+
test.insert(test_extra, ignore_extra_fields=True)
204+
205+
assert len(test_extra.fetch("key")), "test_extra is empty"
206+
test_no_extra.delete()
207+
assert len(test_extra.fetch("key")), "test_extra is empty"
208+
keystr = str(test_extra.fetch("key").max())
209+
test_no_extra.insert((test_extra & "`key`=" + keystr), ignore_extra_fields=True)
210+
211+
212+
def test_skip_duplicates(test_no_extra, test):
213+
"""test that skip_duplicates works when inserting from another table"""
214+
test_no_extra.delete()
215+
test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True)
216+
test_no_extra.insert(test, ignore_extra_fields=True, skip_duplicates=True)
217+
218+
219+
def test_replace(subject):
220+
"""
221+
Test replacing or ignoring duplicate entries
222+
"""
223+
key = dict(subject_id=7)
224+
date = "2015-01-01"
225+
subject.insert1(dict(key, real_id=7, date_of_birth=date, subject_notes=""))
226+
assert date == str((subject & key).fetch1("date_of_birth")), "incorrect insert"
227+
date = "2015-01-02"
228+
subject.insert1(
229+
dict(key, real_id=7, date_of_birth=date, subject_notes=""),
230+
skip_duplicates=True,
231+
)
232+
assert date != str((subject & key).fetch1("date_of_birth")), "inappropriate replace"
233+
subject.insert1(
234+
dict(key, real_id=7, date_of_birth=date, subject_notes=""), replace=True
235+
)
236+
assert date == str((subject & key).fetch1("date_of_birth")), "replace failed"
237+
238+
239+
def test_delete_quick(subject):
240+
"""Tests quick deletion"""
241+
tmp = np.array(
242+
[
243+
(2, "Klara", "monkey", "2010-01-01", ""),
244+
(1, "Peter", "mouse", "2015-01-01", ""),
245+
],
246+
dtype=subject.heading.as_dtype,
247+
)
248+
subject.insert(tmp)
249+
s = subject & ("subject_id in (%s)" % ",".join(str(r) for r in tmp["subject_id"]))
250+
assert len(s) == 2, "insert did not work."
251+
s.delete_quick()
252+
assert len(s) == 0, "delete did not work."
253+
254+
255+
def test_skip_duplicate(subject):
256+
"""Tests if duplicates are properly skipped."""
257+
tmp = np.array(
258+
[
259+
(2, "Klara", "monkey", "2010-01-01", ""),
260+
(1, "Peter", "mouse", "2015-01-01", ""),
261+
],
262+
dtype=subject.heading.as_dtype,
263+
)
264+
subject.insert(tmp)
265+
tmp = np.array(
266+
[
267+
(2, "Klara", "monkey", "2010-01-01", ""),
268+
(1, "Peter", "mouse", "2015-01-01", ""),
269+
],
270+
dtype=subject.heading.as_dtype,
271+
)
272+
subject.insert(tmp, skip_duplicates=True)
273+
274+
275+
def test_not_skip_duplicate(subject):
276+
"""Tests if duplicates are not skipped."""
277+
tmp = np.array(
278+
[
279+
(2, "Klara", "monkey", "2010-01-01", ""),
280+
(2, "Klara", "monkey", "2010-01-01", ""),
281+
(1, "Peter", "mouse", "2015-01-01", ""),
282+
],
283+
dtype=subject.heading.as_dtype,
284+
)
285+
with pytest.raises(dj.errors.DuplicateError):
286+
subject.insert(tmp, skip_duplicates=False)
287+
288+
289+
def test_no_error_suppression(test):
290+
"""skip_duplicates=True should not suppress other errors"""
291+
with pytest.raises(dj.errors.MissingAttributeError):
292+
test.insert([dict(key=100)], skip_duplicates=True)
293+
294+
295+
def test_blob_insert(img):
296+
"""Tests inserting and retrieving blobs."""
297+
X = np.random.randn(20, 10)
298+
img.insert1((1, X))
299+
Y = img.fetch()[0]["img"]
300+
assert np.all(X == Y), "Inserted and retrieved image are not identical"
301+
302+
303+
def test_drop(trash):
304+
"""Tests dropping tables"""
305+
dj.config["safemode"] = True
306+
with patch.object(dj.utils, "input", create=True, return_value="yes"):
307+
trash.drop()
308+
try:
309+
trash.fetch()
310+
raise Exception("Fetched after table dropped.")
311+
except dj.DataJointError:
312+
pass
313+
finally:
314+
dj.config["safemode"] = False
315+
316+
317+
def test_table_regexp(schema_any):
318+
"""Test whether table names are matched by regular expressions"""
319+
320+
def relation_selector(attr):
321+
try:
322+
return issubclass(attr, Table)
323+
except TypeError:
324+
return False
325+
326+
tiers = [dj.Imported, dj.Manual, dj.Lookup, dj.Computed]
327+
for name, rel in getmembers(schema, relation_selector):
328+
assert re.match(
329+
rel.tier_regexp, rel.table_name
330+
), "Regular expression does not match for {name}".format(name=name)
331+
for tier in tiers:
332+
assert issubclass(rel, tier) or not re.match(
333+
tier.tier_regexp, rel.table_name
334+
), "Regular expression matches for {name} but should not".format(name=name)
335+
336+
337+
def test_table_size(experiment):
338+
"""test getting the size of the table and its indices in bytes"""
339+
number_of_bytes = experiment.size_on_disk
340+
assert isinstance(number_of_bytes, int) and number_of_bytes > 100
341+
342+
343+
def test_repr_html(ephys):
344+
assert ephys._repr_html_().strip().startswith("<style")

0 commit comments

Comments
 (0)