Skip to content

Commit 20ab185

Browse files
committed
Format with black
1 parent 9614641 commit 20ab185

File tree

1 file changed

+62
-35
lines changed

1 file changed

+62
-35
lines changed

tests/test_update1.py

Lines changed: 62 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,13 @@
1-
from nose.tools import assert_true, assert_false, assert_equal, raises
1+
import pytest
22
import os
33
import numpy as np
44
from pathlib import Path
55
import tempfile
66
import datajoint as dj
7-
from . import PREFIX, CONN_INFO
7+
from . import PREFIX
88
from datajoint import DataJointError
99

10-
schema = dj.Schema(PREFIX + "_update1", connection=dj.conn(**CONN_INFO))
1110

12-
dj.config["stores"]["update_store"] = dict(protocol="file", location=tempfile.mkdtemp())
13-
14-
dj.config["stores"]["update_repo"] = dict(
15-
stage=tempfile.mkdtemp(), protocol="file", location=tempfile.mkdtemp()
16-
)
17-
18-
19-
scratch_folder = tempfile.mkdtemp()
20-
21-
dj.errors._switch_filepath_types(True)
22-
23-
24-
@schema
2511
class Thing(dj.Manual):
2612
definition = """
2713
thing : int
@@ -35,10 +21,38 @@ class Thing(dj.Manual):
3521
"""
3622

3723

38-
def test_update1():
39-
"""test normal updates"""
24+
@pytest.fixture(scope="module")
25+
def mock_stores_update(tmpdir_factory):
26+
og_stores_config = dj.config.get("stores")
27+
if "stores" not in dj.config:
28+
dj.config["stores"] = {}
29+
dj.config["stores"]["update_store"] = dict(
30+
protocol="file", location=tmpdir_factory.mktemp("store")
31+
)
32+
dj.config["stores"]["update_repo"] = dict(
33+
stage=tmpdir_factory.mktemp("repo_stage"),
34+
protocol="file",
35+
location=tmpdir_factory.mktemp("repo_loc"),
36+
)
37+
yield
38+
if og_stores_config is None:
39+
del dj.config["stores"]
40+
else:
41+
dj.config["stores"] = og_stores_config
4042

41-
dj.errors._switch_filepath_types(True)
43+
44+
@pytest.fixture
45+
def schema_update1(connection_test):
46+
schema = dj.Schema(
47+
PREFIX + "_update1", context=dict(Thing=Thing), connection=connection_test
48+
)
49+
schema(Thing)
50+
yield schema
51+
schema.drop()
52+
53+
54+
def test_update1(tmpdir, enable_filepath_feature, schema_update1, mock_stores_update):
55+
"""Test normal updates"""
4256
# CHECK 1 -- initial insert
4357
key = dict(thing=1)
4458
Thing.insert1(dict(key, frac=0.5))
@@ -48,7 +62,7 @@ def test_update1():
4862
# numbers and datetimes
4963
Thing.update1(dict(key, number=3, frac=30, timestamp="2020-01-01 10:00:00"))
5064
# attachment
51-
attach_file = Path(scratch_folder, "attach1.dat")
65+
attach_file = Path(tmpdir, "attach1.dat")
5266
buffer1 = os.urandom(100)
5367
attach_file.write_bytes(buffer1)
5468
Thing.update1(dict(key, picture=attach_file))
@@ -67,7 +81,7 @@ def test_update1():
6781
managed_file.unlink()
6882
assert not managed_file.is_file()
6983

70-
check2 = Thing.fetch1(download_path=scratch_folder)
84+
check2 = Thing.fetch1(download_path=tmpdir)
7185
buffer2 = Path(check2["picture"]).read_bytes() # read attachment
7286
final_file_data = managed_file.read_bytes() # read filepath
7387

@@ -84,37 +98,50 @@ def test_update1():
8498
)
8599
check3 = Thing.fetch1()
86100

87-
assert check1["number"] == 0 and check1["picture"] is None and check1["params"] is None
101+
assert (
102+
check1["number"] == 0 and check1["picture"] is None and check1["params"] is None
103+
)
88104

89-
assert (check2["number"] == 3
105+
assert (
106+
check2["number"] == 3
90107
and check2["frac"] == 30.0
91108
and check2["picture"] is not None
92109
and check2["params"] is None
93-
and buffer1 == buffer2)
110+
and buffer1 == buffer2
111+
)
94112

95-
assert (check3["number"] == 0
113+
assert (
114+
check3["number"] == 0
96115
and check3["frac"] == 30.0
97116
and check3["picture"] is None
98117
and check3["img_file"] is None
99-
and isinstance(check3["params"], np.ndarray))
118+
and isinstance(check3["params"], np.ndarray)
119+
)
100120

101121
assert check3["timestamp"] > check2["timestamp"]
102122
assert buffer1 == buffer2
103123
assert original_file_data == final_file_data
104124

105125

106-
@raises(DataJointError)
107-
def test_update1_nonexistent():
108-
Thing.update1(dict(thing=100, frac=0.5)) # updating a non-existent entry
126+
def test_update1_nonexistent(
127+
enable_filepath_feature, schema_update1, mock_stores_update
128+
):
129+
with pytest.raises(DataJointError):
130+
# updating a non-existent entry
131+
Thing.update1(dict(thing=100, frac=0.5))
109132

110133

111-
@raises(DataJointError)
112-
def test_update1_noprimary():
113-
Thing.update1(dict(number=None)) # missing primary key
134+
def test_update1_noprimary(enable_filepath_feature, schema_update1, mock_stores_update):
135+
with pytest.raises(DataJointError):
136+
# missing primary key
137+
Thing.update1(dict(number=None))
114138

115139

116-
@raises(DataJointError)
117-
def test_update1_misspelled_attribute():
140+
def test_update1_misspelled_attribute(
141+
enable_filepath_feature, schema_update1, mock_stores_update
142+
):
118143
key = dict(thing=17)
119144
Thing.insert1(dict(key, frac=1.5))
120-
Thing.update1(dict(key, numer=3)) # misspelled attribute
145+
with pytest.raises(DataJointError):
146+
# misspelled attribute
147+
Thing.update1(dict(key, numer=3))

0 commit comments

Comments
 (0)