Skip to content

Commit fa2ae33

Browse files
authored
Merge pull request #1139 from ethho/dev-tests-plat-169-update1
PLAT-169: Migrate test_update1
2 parents d474c7d + 20ab185 commit fa2ae33

File tree

1 file changed

+147
-0
lines changed

1 file changed

+147
-0
lines changed

tests/test_update1.py

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import pytest
2+
import os
3+
import numpy as np
4+
from pathlib import Path
5+
import tempfile
6+
import datajoint as dj
7+
from . import PREFIX
8+
from datajoint import DataJointError
9+
10+
11+
class Thing(dj.Manual):
12+
definition = """
13+
thing : int
14+
---
15+
number=0 : int
16+
frac : float
17+
picture = null : attach@update_store
18+
params = null : longblob
19+
img_file = null: filepath@update_repo
20+
timestamp = CURRENT_TIMESTAMP : datetime
21+
"""
22+
23+
24+
@pytest.fixture(scope="module")
25+
def mock_stores_update(tmpdir_factory):
26+
og_stores_config = dj.config.get("stores")
27+
if "stores" not in dj.config:
28+
dj.config["stores"] = {}
29+
dj.config["stores"]["update_store"] = dict(
30+
protocol="file", location=tmpdir_factory.mktemp("store")
31+
)
32+
dj.config["stores"]["update_repo"] = dict(
33+
stage=tmpdir_factory.mktemp("repo_stage"),
34+
protocol="file",
35+
location=tmpdir_factory.mktemp("repo_loc"),
36+
)
37+
yield
38+
if og_stores_config is None:
39+
del dj.config["stores"]
40+
else:
41+
dj.config["stores"] = og_stores_config
42+
43+
44+
@pytest.fixture
45+
def schema_update1(connection_test):
46+
schema = dj.Schema(
47+
PREFIX + "_update1", context=dict(Thing=Thing), connection=connection_test
48+
)
49+
schema(Thing)
50+
yield schema
51+
schema.drop()
52+
53+
54+
def test_update1(tmpdir, enable_filepath_feature, schema_update1, mock_stores_update):
55+
"""Test normal updates"""
56+
# CHECK 1 -- initial insert
57+
key = dict(thing=1)
58+
Thing.insert1(dict(key, frac=0.5))
59+
check1 = Thing.fetch1()
60+
61+
# CHECK 2 -- some updates
62+
# numbers and datetimes
63+
Thing.update1(dict(key, number=3, frac=30, timestamp="2020-01-01 10:00:00"))
64+
# attachment
65+
attach_file = Path(tmpdir, "attach1.dat")
66+
buffer1 = os.urandom(100)
67+
attach_file.write_bytes(buffer1)
68+
Thing.update1(dict(key, picture=attach_file))
69+
attach_file.unlink()
70+
assert not attach_file.is_file()
71+
72+
# filepath
73+
stage_path = dj.config["stores"]["update_repo"]["stage"]
74+
relpath, filename = "one/two/three", "picture.dat"
75+
managed_file = Path(stage_path, relpath, filename)
76+
managed_file.parent.mkdir(parents=True, exist_ok=True)
77+
original_file_data = os.urandom(3000)
78+
with managed_file.open("wb") as f:
79+
f.write(original_file_data)
80+
Thing.update1(dict(key, img_file=managed_file))
81+
managed_file.unlink()
82+
assert not managed_file.is_file()
83+
84+
check2 = Thing.fetch1(download_path=tmpdir)
85+
buffer2 = Path(check2["picture"]).read_bytes() # read attachment
86+
final_file_data = managed_file.read_bytes() # read filepath
87+
88+
# CHECK 3 -- reset to default values using None
89+
Thing.update1(
90+
dict(
91+
key,
92+
number=None,
93+
timestamp=None,
94+
picture=None,
95+
img_file=None,
96+
params=np.random.randn(3, 3),
97+
)
98+
)
99+
check3 = Thing.fetch1()
100+
101+
assert (
102+
check1["number"] == 0 and check1["picture"] is None and check1["params"] is None
103+
)
104+
105+
assert (
106+
check2["number"] == 3
107+
and check2["frac"] == 30.0
108+
and check2["picture"] is not None
109+
and check2["params"] is None
110+
and buffer1 == buffer2
111+
)
112+
113+
assert (
114+
check3["number"] == 0
115+
and check3["frac"] == 30.0
116+
and check3["picture"] is None
117+
and check3["img_file"] is None
118+
and isinstance(check3["params"], np.ndarray)
119+
)
120+
121+
assert check3["timestamp"] > check2["timestamp"]
122+
assert buffer1 == buffer2
123+
assert original_file_data == final_file_data
124+
125+
126+
def test_update1_nonexistent(
127+
enable_filepath_feature, schema_update1, mock_stores_update
128+
):
129+
with pytest.raises(DataJointError):
130+
# updating a non-existent entry
131+
Thing.update1(dict(thing=100, frac=0.5))
132+
133+
134+
def test_update1_noprimary(enable_filepath_feature, schema_update1, mock_stores_update):
135+
with pytest.raises(DataJointError):
136+
# missing primary key
137+
Thing.update1(dict(number=None))
138+
139+
140+
def test_update1_misspelled_attribute(
141+
enable_filepath_feature, schema_update1, mock_stores_update
142+
):
143+
key = dict(thing=17)
144+
Thing.insert1(dict(key, frac=1.5))
145+
with pytest.raises(DataJointError):
146+
# misspelled attribute
147+
Thing.update1(dict(key, numer=3))

0 commit comments

Comments
 (0)