Skip to content

Commit 49670c9

Browse files
committed
Use pytest tmpdir fixtures for mocking stores
1 parent cdb4052 commit 49670c9

File tree

1 file changed

+28
-10
lines changed

1 file changed

+28
-10
lines changed

tests/conftest.py

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010
import networkx as nx
1111
import json
1212
from pathlib import Path
13-
import tempfile
1413
from datajoint import errors
1514
from datajoint.errors import ADAPTED_TYPE_SWITCH, FILEPATH_FEATURE_SWITCH
1615
from . import (
@@ -176,23 +175,45 @@ def connection_test(connection_root):
176175

177176

178177
@pytest.fixture(scope="session")
179-
def stores_config():
178+
def stores_config(tmpdir_factory):
180179
stores_config = {
181-
"raw": dict(protocol="file", location=tempfile.mkdtemp()),
180+
"raw": dict(protocol="file", location=tmpdir_factory.mktemp("raw")),
182181
"repo": dict(
183-
stage=tempfile.mkdtemp(), protocol="file", location=tempfile.mkdtemp()
182+
stage=tmpdir_factory.mktemp("repo"), protocol="file", location=tmpdir_factory.mktemp("repo")
184183
),
185184
"repo-s3": dict(
186-
S3_CONN_INFO, protocol="s3", location="dj/repo", stage=tempfile.mkdtemp()
185+
S3_CONN_INFO, protocol="s3", location="dj/repo", stage=tmpdir_factory.mktemp("repo-s3")
187186
),
188-
"local": dict(protocol="file", location=tempfile.mkdtemp(), subfolding=(1, 1)),
187+
"local": dict(protocol="file", location=tmpdir_factory.mktemp("local"), subfolding=(1, 1)),
189188
"share": dict(
190189
S3_CONN_INFO, protocol="s3", location="dj/store/repo", subfolding=(2, 4)
191190
),
192191
}
193192
return stores_config
194193

195194

195+
@pytest.fixture
196+
def mock_stores(stores_config):
197+
og_stores_config = dj.config.get("stores")
198+
dj.config["stores"] = stores_config
199+
yield
200+
if og_stores_config is None:
201+
del dj.config["stores"]
202+
else:
203+
dj.config["stores"] = og_stores_config
204+
205+
206+
@pytest.fixture
207+
def mock_cache(tmpdir_factory):
208+
og_cache = dj.config.get("cache")
209+
dj.config["cache"] = tmpdir_factory.mktemp("cache")
210+
yield
211+
if og_cache is None:
212+
del dj.config["cache"]
213+
else:
214+
dj.config["cache"] = og_cache
215+
216+
196217
@pytest.fixture
197218
def schema_any(connection_test):
198219
schema_any = dj.Schema(
@@ -287,15 +308,12 @@ def schema_adv(connection_test):
287308

288309

289310
@pytest.fixture
290-
def schema_ext(connection_test, stores_config, enable_filepath_feature):
311+
def schema_ext(connection_test, enable_filepath_feature, mock_stores, mock_cache):
291312
schema = dj.Schema(
292313
PREFIX + "_extern",
293314
context=schema_external.LOCALS_EXTERNAL,
294315
connection=connection_test,
295316
)
296-
dj.config["stores"] = stores_config
297-
dj.config["cache"] = tempfile.mkdtemp()
298-
299317
schema(schema_external.Simple)
300318
schema(schema_external.SimpleRemote)
301319
schema(schema_external.Seed)

0 commit comments

Comments
 (0)