Skip to content

Commit 4da415c

Browse files
committed
Extract base test class for testing with a database connection
1 parent 1fb192c commit 4da415c

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

tests/test_utils.py

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,46 @@
1+
import json
12
import asyncio
3+
import tempfile
24
import unittest
3-
from typing import TypeVar, Awaitable
5+
from typing import IO, TypeVar, Awaitable
6+
7+
import alembic # type: ignore[import]
8+
import databases
9+
from sqlalchemy import create_engine
10+
from alembic.config import Config # type: ignore[import]
11+
from starlette.config import environ
412

513
T = TypeVar('T')
614

715

16+
DATABASE_FILE: IO[bytes]
17+
18+
19+
def ensure_database_configured() -> None:
20+
global DATABASE_FILE
21+
22+
try:
23+
DATABASE_FILE
24+
return
25+
except NameError:
26+
pass
27+
28+
DATABASE_FILE = tempfile.NamedTemporaryFile(suffix='sqlite.db')
29+
url = 'sqlite:///{}'.format(DATABASE_FILE.name)
30+
31+
environ['TESTING'] = 'True'
32+
environ['DATABASE_URL'] = url
33+
34+
environ['AUTH_BACKEND'] = json.dumps({
35+
'backend': 'code_submitter.auth.DummyBackend',
36+
'kwargs': {'team': 'SRZ2'},
37+
})
38+
39+
create_engine(url)
40+
41+
alembic.command.upgrade(Config('alembic.ini'), 'head')
42+
43+
844
class AsyncTestCase(unittest.TestCase):
945
def await_(self, awaitable: Awaitable[T]) -> T:
1046
return self.loop.run_until_complete(awaitable)
@@ -13,3 +49,17 @@ def setUp(self) -> None:
1349
super().setUp()
1450

1551
self.loop = asyncio.get_event_loop()
52+
53+
54+
class DatabaseTestCase(AsyncTestCase):
55+
database: databases.Database
56+
57+
@classmethod
58+
def setUpClass(cls) -> None:
59+
super().setUpClass()
60+
ensure_database_configured()
61+
62+
# Import must happen after TESTING environment setup
63+
from code_submitter.server import database
64+
65+
cls.database = database

tests/tests_app.py

Lines changed: 2 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,47 +1,19 @@
11
import io
2-
import json
32
import zipfile
43
import datetime
5-
import tempfile
6-
from typing import IO
74
from unittest import mock
85

9-
import alembic # type: ignore[import]
106
import test_utils
11-
from sqlalchemy import create_engine
12-
from alembic.config import Config # type: ignore[import]
13-
from starlette.config import environ
147
from starlette.testclient import TestClient
158
from code_submitter.tables import Archive, ChoiceHistory
169

17-
DATABASE_FILE: IO[bytes]
1810

19-
20-
def setUpModule() -> None:
21-
global DATABASE_FILE
22-
23-
DATABASE_FILE = tempfile.NamedTemporaryFile(suffix='sqlite.db')
24-
url = 'sqlite:///{}'.format(DATABASE_FILE.name)
25-
26-
environ['TESTING'] = 'True'
27-
environ['DATABASE_URL'] = url
28-
29-
environ['AUTH_BACKEND'] = json.dumps({
30-
'backend': 'code_submitter.auth.DummyBackend',
31-
'kwargs': {'team': 'SRZ2'},
32-
})
33-
34-
create_engine(url)
35-
36-
alembic.command.upgrade(Config('alembic.ini'), 'head')
37-
38-
39-
class AppTests(test_utils.AsyncTestCase):
11+
class AppTests(test_utils.DatabaseTestCase):
4012
def setUp(self) -> None:
4113
super().setUp()
4214

4315
# App import must happen after TESTING environment setup
44-
from code_submitter.server import app, database
16+
from code_submitter.server import app
4517

4618
def url_for(name: str) -> str:
4719
# While it makes for uglier tests, we do need to use more absolute
@@ -53,7 +25,6 @@ def url_for(name: str) -> str:
5325
self.session = test_client.__enter__()
5426
self.session.auth = ('test_user', 'test_pass')
5527
self.url_for = url_for
56-
self.database = database
5728

5829
def tearDown(self) -> None:
5930
self.session.__exit__(None, None, None)

0 commit comments

Comments
 (0)