Skip to content

Commit 803deb5

Browse files
authored
Adds session handling in middleware (#41)
* Adds session handling in middleware * Adds submodule
1 parent 972dd54 commit 803deb5

File tree

2 files changed

+13
-10
lines changed

2 files changed

+13
-10
lines changed

app.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from fastapi import FastAPI, HTTPException, responses, status
1+
from fastapi import FastAPI, HTTPException, responses, status, Request
22
from pydantic import BaseModel
33
from typing import Union, Dict, Optional
44

@@ -36,57 +36,60 @@ class ExportWsStatsRequest(BaseModel):
3636
overwrite_weak_supervision: Optional[Union[float, Dict[str, float]]]
3737

3838

39+
@app.middleware("http")
40+
async def handle_db_session(request: Request, call_next):
41+
session_token = general.get_ctx_token()
42+
try:
43+
response = await call_next(request)
44+
finally:
45+
general.remove_and_refresh_session(session_token)
46+
47+
return response
48+
49+
3950
@app.post("/fit_predict")
4051
def weakly_supervise(
4152
request: WeakSupervisionRequest,
4253
) -> responses.PlainTextResponse:
43-
session_token = general.get_ctx_token()
4454
integration.fit_predict(
4555
request.project_id,
4656
request.labeling_task_id,
4757
request.user_id,
4858
request.weak_supervision_task_id,
4959
request.overwrite_weak_supervision,
5060
)
51-
general.remove_and_refresh_session(session_token)
5261
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
5362

5463

5564
@app.post("/labeling_task_statistics")
5665
def calculate_task_stats(
5766
request: TaskStatsRequest,
5867
) -> responses.PlainTextResponse:
59-
session_token = general.get_ctx_token()
6068
stats.calculate_quality_statistics_for_labeling_task(
6169
request.project_id, request.labeling_task_id, request.user_id
6270
)
63-
general.remove_and_refresh_session(session_token)
6471
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
6572

6673

6774
@app.post("/source_statistics")
6875
def calculate_source_stats(
6976
request: SourceStatsRequest,
7077
) -> responses.PlainTextResponse:
71-
session_token = general.get_ctx_token()
7278
has_coverage = stats.calculate_quantity_statistics_for_labeling_task_from_source(
7379
request.project_id, request.source_id, request.user_id
7480
)
7581
if has_coverage:
7682
stats.calculate_quality_statistics_for_source(
7783
request.project_id, request.source_id, request.user_id
7884
)
79-
general.remove_and_refresh_session(session_token)
8085
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
8186

8287

8388
@app.post("/export_ws_stats")
8489
def export_ws_stats(request: ExportWsStatsRequest) -> responses.PlainTextResponse:
85-
session_token = general.get_ctx_token()
8690
status_code, message = integration.export_weak_supervision_stats(
8791
request.project_id, request.labeling_task_id, request.overwrite_weak_supervision
8892
)
89-
general.remove_and_refresh_session(session_token)
9093

9194
if status_code != 200:
9295
raise HTTPException(status_code=status_code, detail=message)

0 commit comments

Comments
 (0)