1
- from nose . tools import assert_true , assert_false , assert_equal , raises
1
+ import pytest
2
2
import os
3
3
import numpy as np
4
4
from pathlib import Path
5
5
import tempfile
6
6
import datajoint as dj
7
- from . import PREFIX , CONN_INFO
7
+ from . import PREFIX
8
8
from datajoint import DataJointError
9
9
10
- schema = dj .Schema (PREFIX + "_update1" , connection = dj .conn (** CONN_INFO ))
11
10
12
- dj .config ["stores" ]["update_store" ] = dict (protocol = "file" , location = tempfile .mkdtemp ())
13
-
14
- dj .config ["stores" ]["update_repo" ] = dict (
15
- stage = tempfile .mkdtemp (), protocol = "file" , location = tempfile .mkdtemp ()
16
- )
17
-
18
-
19
- scratch_folder = tempfile .mkdtemp ()
20
-
21
- dj .errors ._switch_filepath_types (True )
22
-
23
-
24
- @schema
25
11
class Thing (dj .Manual ):
26
12
definition = """
27
13
thing : int
@@ -35,10 +21,38 @@ class Thing(dj.Manual):
35
21
"""
36
22
37
23
38
- def test_update1 ():
39
- """test normal updates"""
24
+ @pytest .fixture (scope = "module" )
25
+ def mock_stores_update (tmpdir_factory ):
26
+ og_stores_config = dj .config .get ("stores" )
27
+ if "stores" not in dj .config :
28
+ dj .config ["stores" ] = {}
29
+ dj .config ["stores" ]["update_store" ] = dict (
30
+ protocol = "file" , location = tmpdir_factory .mktemp ("store" )
31
+ )
32
+ dj .config ["stores" ]["update_repo" ] = dict (
33
+ stage = tmpdir_factory .mktemp ("repo_stage" ),
34
+ protocol = "file" ,
35
+ location = tmpdir_factory .mktemp ("repo_loc" ),
36
+ )
37
+ yield
38
+ if og_stores_config is None :
39
+ del dj .config ["stores" ]
40
+ else :
41
+ dj .config ["stores" ] = og_stores_config
40
42
41
- dj .errors ._switch_filepath_types (True )
43
+
44
+ @pytest .fixture
45
+ def schema_update1 (connection_test ):
46
+ schema = dj .Schema (
47
+ PREFIX + "_update1" , context = dict (Thing = Thing ), connection = connection_test
48
+ )
49
+ schema (Thing )
50
+ yield schema
51
+ schema .drop ()
52
+
53
+
54
+ def test_update1 (tmpdir , enable_filepath_feature , schema_update1 , mock_stores_update ):
55
+ """Test normal updates"""
42
56
# CHECK 1 -- initial insert
43
57
key = dict (thing = 1 )
44
58
Thing .insert1 (dict (key , frac = 0.5 ))
@@ -48,7 +62,7 @@ def test_update1():
48
62
# numbers and datetimes
49
63
Thing .update1 (dict (key , number = 3 , frac = 30 , timestamp = "2020-01-01 10:00:00" ))
50
64
# attachment
51
- attach_file = Path (scratch_folder , "attach1.dat" )
65
+ attach_file = Path (tmpdir , "attach1.dat" )
52
66
buffer1 = os .urandom (100 )
53
67
attach_file .write_bytes (buffer1 )
54
68
Thing .update1 (dict (key , picture = attach_file ))
@@ -67,7 +81,7 @@ def test_update1():
67
81
managed_file .unlink ()
68
82
assert not managed_file .is_file ()
69
83
70
- check2 = Thing .fetch1 (download_path = scratch_folder )
84
+ check2 = Thing .fetch1 (download_path = tmpdir )
71
85
buffer2 = Path (check2 ["picture" ]).read_bytes () # read attachment
72
86
final_file_data = managed_file .read_bytes () # read filepath
73
87
@@ -84,37 +98,50 @@ def test_update1():
84
98
)
85
99
check3 = Thing .fetch1 ()
86
100
87
- assert check1 ["number" ] == 0 and check1 ["picture" ] is None and check1 ["params" ] is None
101
+ assert (
102
+ check1 ["number" ] == 0 and check1 ["picture" ] is None and check1 ["params" ] is None
103
+ )
88
104
89
- assert (check2 ["number" ] == 3
105
+ assert (
106
+ check2 ["number" ] == 3
90
107
and check2 ["frac" ] == 30.0
91
108
and check2 ["picture" ] is not None
92
109
and check2 ["params" ] is None
93
- and buffer1 == buffer2 )
110
+ and buffer1 == buffer2
111
+ )
94
112
95
- assert (check3 ["number" ] == 0
113
+ assert (
114
+ check3 ["number" ] == 0
96
115
and check3 ["frac" ] == 30.0
97
116
and check3 ["picture" ] is None
98
117
and check3 ["img_file" ] is None
99
- and isinstance (check3 ["params" ], np .ndarray ))
118
+ and isinstance (check3 ["params" ], np .ndarray )
119
+ )
100
120
101
121
assert check3 ["timestamp" ] > check2 ["timestamp" ]
102
122
assert buffer1 == buffer2
103
123
assert original_file_data == final_file_data
104
124
105
125
106
- @raises (DataJointError )
107
- def test_update1_nonexistent ():
108
- Thing .update1 (dict (thing = 100 , frac = 0.5 )) # updating a non-existent entry
126
+ def test_update1_nonexistent (
127
+ enable_filepath_feature , schema_update1 , mock_stores_update
128
+ ):
129
+ with pytest .raises (DataJointError ):
130
+ # updating a non-existent entry
131
+ Thing .update1 (dict (thing = 100 , frac = 0.5 ))
109
132
110
133
111
- @raises (DataJointError )
112
- def test_update1_noprimary ():
113
- Thing .update1 (dict (number = None )) # missing primary key
134
+ def test_update1_noprimary (enable_filepath_feature , schema_update1 , mock_stores_update ):
135
+ with pytest .raises (DataJointError ):
136
+ # missing primary key
137
+ Thing .update1 (dict (number = None ))
114
138
115
139
116
- @raises (DataJointError )
117
- def test_update1_misspelled_attribute ():
140
+ def test_update1_misspelled_attribute (
141
+ enable_filepath_feature , schema_update1 , mock_stores_update
142
+ ):
118
143
key = dict (thing = 17 )
119
144
Thing .insert1 (dict (key , frac = 1.5 ))
120
- Thing .update1 (dict (key , numer = 3 )) # misspelled attribute
145
+ with pytest .raises (DataJointError ):
146
+ # misspelled attribute
147
+ Thing .update1 (dict (key , numer = 3 ))
0 commit comments