Skip to content

Commit 3d8c1ef

Browse files
committed
Download archives by session
This replaces the previous download of only the latest archives with a means to download the archives which were chosen for a given session.
1 parent c9718ad commit 3d8c1ef

File tree

6 files changed

+188
-41
lines changed

6 files changed

+188
-41
lines changed

code_submitter/extract_archives.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,31 +3,43 @@
33
import asyncio
44
import zipfile
55
import argparse
6+
from typing import cast
67
from pathlib import Path
78

89
import databases
10+
from sqlalchemy.sql import select
911

1012
from . import utils, config
13+
from .tables import Session
1114

1215

13-
async def async_main(output_archive: Path) -> None:
16+
async def async_main(output_archive: Path, session_name: str) -> None:
1417
output_archive.parent.mkdir(parents=True, exist_ok=True)
1518

1619
database = databases.Database(config.DATABASE_URL)
1720

21+
session_id = cast(int, await database.fetch_one(select([
22+
Session.c.id,
23+
]).where(
24+
Session.c.name == session_name,
25+
)))
26+
1827
with zipfile.ZipFile(output_archive) as zf:
1928
async with database.transaction():
20-
utils.collect_submissions(database, zf)
29+
utils.collect_submissions(database, zf, session_id)
2130

2231

2332
def parse_args() -> argparse.Namespace:
2433
parser = argparse.ArgumentParser()
34+
parser.add_argument('session_name', type=str)
2535
parser.add_argument('output_archive', type=Path)
2636
return parser.parse_args()
2737

2838

2939
def main(args: argparse.Namespace) -> None:
30-
asyncio.get_event_loop().run_until_complete(async_main(args.output_archive))
40+
asyncio.get_event_loop().run_until_complete(
41+
async_main(args.output_archive, args.session_name),
42+
)
3143

3244

3345
if __name__ == '__main__':

code_submitter/server.py

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import io
22
import zipfile
3-
import datetime
3+
from typing import cast
44

55
import databases
66
from sqlalchemy.sql import select
@@ -16,7 +16,7 @@
1616

1717
from . import auth, utils, config
1818
from .auth import User, BLUESHIRT_SCOPE
19-
from .tables import Archive, ChoiceHistory
19+
from .tables import Archive, Session, ChoiceHistory
2020

2121
database = databases.Database(config.DATABASE_URL, force_rollback=config.TESTING)
2222
templates = Jinja2Templates(directory='templates')
@@ -49,10 +49,14 @@ async def homepage(request: Request) -> Response:
4949
Archive.c.created.desc(),
5050
),
5151
)
52+
sessions = await database.fetch_all(
53+
Session.select().order_by(Session.c.created.desc()),
54+
)
5255
return templates.TemplateResponse('index.html', {
5356
'request': request,
5457
'chosen': chosen,
5558
'uploads': uploads,
59+
'sessions': sessions,
5660
'BLUESHIRT_SCOPE': BLUESHIRT_SCOPE,
5761
})
5862

@@ -137,14 +141,25 @@ async def create_session(request: Request) -> Response:
137141

138142

139143
@requires(['authenticated', BLUESHIRT_SCOPE])
144+
@database.transaction()
140145
async def download_submissions(request: Request) -> Response:
146+
session_id = cast(int, request.path_params['session_id'])
147+
148+
session = await database.fetch_one(
149+
Session.select().where(Session.c.id == session_id),
150+
)
151+
152+
if session is None:
153+
return Response(
154+
f"{session_id!r} is not a valid session id",
155+
status_code=404,
156+
)
157+
141158
buffer = io.BytesIO()
142159
with zipfile.ZipFile(buffer, mode='w') as zf:
143-
await utils.collect_submissions(database, zf)
160+
await utils.collect_submissions(database, zf, session_id)
144161

145-
filename = 'submissions-{now}.zip'.format(
146-
now=datetime.datetime.now(datetime.timezone.utc),
147-
)
162+
filename = f"submissions-{session['name']}.zip"
148163

149164
return Response(
150165
buffer.getvalue(),
@@ -157,7 +172,11 @@ async def download_submissions(request: Request) -> Response:
157172
Route('/', endpoint=homepage, methods=['GET']),
158173
Route('/upload', endpoint=upload, methods=['POST']),
159174
Route('/create-session', endpoint=create_session, methods=['POST']),
160-
Route('/download-submissions', endpoint=download_submissions, methods=['GET']),
175+
Route(
176+
'/download-submissions/{session_id:int}',
177+
endpoint=download_submissions,
178+
methods=['GET'],
179+
),
161180
]
162181

163182
middleware = [

code_submitter/utils.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,29 +9,24 @@
99

1010
async def get_chosen_submissions(
1111
database: databases.Database,
12+
session_id: int,
1213
) -> Dict[str, Tuple[int, bytes]]:
1314
"""
1415
Return a mapping of teams to their the chosen archive.
1516
"""
1617

17-
# Note: Ideally we'd group by team in SQL, however that doesn't seem to work
18-
# properly -- we don't get the ordering applied before the grouping.
19-
2018
rows = await database.fetch_all(
2119
select([
2220
Archive.c.id,
2321
Archive.c.team,
2422
Archive.c.content,
25-
ChoiceHistory.c.created,
2623
]).select_from(
27-
Archive.join(ChoiceHistory),
28-
).order_by(
29-
Archive.c.team,
30-
ChoiceHistory.c.created.asc(),
24+
Archive.join(ChoiceHistory).join(ChoiceForSession),
25+
).where(
26+
Session.c.id == session_id,
3127
),
3228
)
3329

34-
# Rely on later keys replacing earlier occurrences of the same key.
3530
return {x['team']: (x['id'], x['content']) for x in rows}
3631

3732

@@ -88,8 +83,9 @@ def summarise(submissions: Dict[str, Tuple[int, bytes]]) -> str:
8883
async def collect_submissions(
8984
database: databases.Database,
9085
zipfile: ZipFile,
86+
session_id: int,
9187
) -> None:
92-
submissions = await get_chosen_submissions(database)
88+
submissions = await get_chosen_submissions(database, session_id)
9389

9490
for team, (_, content) in submissions.items():
9591
zipfile.writestr(f'{team.upper()}.zip', content)

templates/index.html

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,14 +31,40 @@
3131
<body>
3232
<div class="container">
3333
<h1>Virtual Competition Code Submission</h1>
34-
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
3534
<div class="row">
3635
<div class="col-sm-6">
37-
<a download href="{{ url_for('download_submissions') }}">
38-
Download current chosen submissions
39-
</a>
36+
<h3>Sessions</h3>
37+
<table class="table table-striped">
38+
<tr>
39+
<th scope="col">Name</th>
40+
<th scope="col">Created</th>
41+
<th scope="col">By</th>
42+
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
43+
<th scope="col">Download</th>
44+
{% endif %}
45+
</tr>
46+
{% for session in sessions %}
47+
<tr>
48+
<td>{{ session.name }}</td>
49+
<td>{{ session.created }}</td>
50+
<td>{{ session.username }}</td>
51+
<!-- TODO: teams in the session -->
52+
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
53+
<td>
54+
<a
55+
download
56+
href="{{ url_for('download_submissions', session_id=session.id) }}"
57+
>
58+
59+
</a>
60+
</td>
61+
{% endif %}
62+
</tr>
63+
{% endfor %}
64+
</table>
4065
</div>
4166
</div>
67+
{% if BLUESHIRT_SCOPE in request.auth.scopes %}
4268
<div class="row">
4369
<div class="col-sm-6">
4470
<form

tests/tests_app.py

Lines changed: 58 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from sqlalchemy.sql import select
88
from starlette.testclient import TestClient
99

10-
from code_submitter.tables import Archive, Session, ChoiceHistory
10+
from code_submitter.tables import (
11+
Archive,
12+
Session,
13+
ChoiceHistory,
14+
ChoiceForSession,
15+
)
1116

1217

1318
class AppTests(test_utils.DatabaseTestCase):
@@ -17,11 +22,11 @@ def setUp(self) -> None:
1722
# App import must happen after TESTING environment setup
1823
from code_submitter.server import app
1924

20-
def url_for(name: str) -> str:
25+
def url_for(name: str, **path_params: str) -> str:
2126
# While it makes for uglier tests, we do need to use more absolute
2227
# paths here so that the urls emitted contain the root_path from the
2328
# ASGI server and in turn work correctly under proxy.
24-
return 'http://testserver{}'.format(app.url_path_for(name))
29+
return 'http://testserver{}'.format(app.url_path_for(name, **path_params))
2530

2631
test_client = TestClient(app)
2732
self.session = test_client.__enter__()
@@ -315,7 +320,13 @@ def test_create_session(self) -> None:
315320
)
316321

317322
def test_no_download_link_for_non_blueshirt(self) -> None:
318-
download_url = self.url_for('download_submissions')
323+
session_id = self.await_(self.database.execute(
324+
Session.insert().values(
325+
name="Test session",
326+
username='blueshirt',
327+
),
328+
))
329+
download_url = self.url_for('download_submissions', session_id=session_id)
319330

320331
response = self.session.get(self.url_for('homepage'))
321332

@@ -325,20 +336,43 @@ def test_no_download_link_for_non_blueshirt(self) -> None:
325336
def test_shows_download_link_for_blueshirt(self) -> None:
326337
self.session.auth = ('blueshirt', 'blueshirt')
327338

328-
download_url = self.url_for('download_submissions')
339+
session_id = self.await_(self.database.execute(
340+
Session.insert().values(
341+
name="Test session",
342+
username='blueshirt',
343+
),
344+
))
345+
download_url = self.url_for('download_submissions', session_id=session_id)
329346

330347
response = self.session.get(self.url_for('homepage'))
331348
html = response.text
332349
self.assertIn(download_url, html)
333350

334351
def test_download_submissions_requires_blueshirt(self) -> None:
335-
response = self.session.get(self.url_for('download_submissions'))
352+
session_id = self.await_(self.database.execute(
353+
Session.insert().values(
354+
name="Test session",
355+
username='blueshirt',
356+
),
357+
))
358+
response = self.session.get(
359+
self.url_for('download_submissions', session_id=session_id),
360+
)
336361
self.assertEqual(403, response.status_code)
337362

338363
def test_download_submissions_when_none(self) -> None:
339364
self.session.auth = ('blueshirt', 'blueshirt')
340365

341-
response = self.session.get(self.url_for('download_submissions'))
366+
session_id = self.await_(self.database.execute(
367+
Session.insert().values(
368+
name="Test session",
369+
username='blueshirt',
370+
),
371+
))
372+
373+
response = self.session.get(
374+
self.url_for('download_submissions', session_id=session_id),
375+
)
342376
self.assertEqual(200, response.status_code)
343377

344378
with zipfile.ZipFile(io.BytesIO(response.content)) as zf:
@@ -359,15 +393,30 @@ def test_download_submissions(self) -> None:
359393
created=datetime.datetime(2020, 8, 8, 12, 0),
360394
),
361395
))
362-
self.await_(self.database.execute(
396+
choice_id = self.await_(self.database.execute(
363397
ChoiceHistory.insert().values(
364398
archive_id=8888888888,
365399
username='test_user',
366400
created=datetime.datetime(2020, 9, 9, 12, 0),
367401
),
368402
))
369403

370-
response = self.session.get(self.url_for('download_submissions'))
404+
session_id = self.await_(self.database.execute(
405+
Session.insert().values(
406+
name="Test session",
407+
username='blueshirt',
408+
),
409+
))
410+
self.await_(self.database.execute(
411+
ChoiceForSession.insert().values(
412+
choice_id=choice_id,
413+
session_id=session_id,
414+
),
415+
))
416+
417+
response = self.session.get(
418+
self.url_for('download_submissions', session_id=session_id),
419+
)
371420
self.assertEqual(200, response.status_code)
372421

373422
with zipfile.ZipFile(io.BytesIO(response.content)) as zf:

0 commit comments

Comments
 (0)