Skip to content

Commit 106162f

Browse files
JWittmeyerJWittmeyer
andauthored
Fast Api JSONResponst change (#31)
* Adds response classes * config change * Doc ock change --------- Co-authored-by: JWittmeyer <[email protected]>
1 parent 2f3c672 commit 106162f

File tree

4 files changed

+43
-34
lines changed

4 files changed

+43
-34
lines changed

app.py

Lines changed: 19 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- coding: utf-8 -*-
2-
from fastapi import FastAPI
2+
from fastapi import FastAPI, responses, status
33
import controller
44
from data import data_type
55
from typing import List, Dict, Tuple
@@ -25,7 +25,7 @@
2525
@app.get("/classification/recommend/{data_type}")
2626
def recommendations(
2727
data_type: str,
28-
) -> Tuple[List[Dict[str, str]], int]:
28+
) -> responses.JSONResponse:
2929
recommends = [
3030
### English ###
3131
{
@@ -92,39 +92,44 @@ def recommendations(
9292
},
9393
]
9494

95-
return recommends, 200
95+
return responses.JSONResponse(status_code=status.HTTP_200_OK, content=recommends)
9696

9797

9898
@app.post("/classification/encode")
99-
def encode_classification(request: data_type.Request) -> Tuple[int, str]:
99+
def encode_classification(request: data_type.Request) -> responses.PlainTextResponse:
100100
# session logic for threads in side
101-
return controller.start_encoding_thread(request, "classification"), ""
101+
status_code = controller.start_encoding_thread(request, "classification")
102+
103+
return responses.PlainTextResponse(status_code=status_code)
102104

103105

104106
@app.post("/extraction/encode")
105-
def encode_extraction(request: data_type.Request) -> Tuple[int, str]:
107+
def encode_extraction(request: data_type.Request) -> responses.PlainTextResponse:
106108
# session logic for threads in side
107-
return controller.start_encoding_thread(request, "extraction"), ""
109+
status_code = controller.start_encoding_thread(request, "extraction")
110+
return responses.PlainTextResponse(status_code=status_code)
108111

109112

110113
@app.delete("/delete/{project_id}/{embedding_id}")
111-
def delete_embedding(project_id: str, embedding_id: str) -> Tuple[int, str]:
114+
def delete_embedding(project_id: str, embedding_id: str) -> responses.PlainTextResponse:
112115
session_token = general.get_ctx_token()
113-
return_value = controller.delete_embedding(project_id, embedding_id)
116+
status_code = controller.delete_embedding(project_id, embedding_id)
114117
general.remove_and_refresh_session(session_token)
115-
return return_value, ""
118+
return responses.PlainTextResponse(status_code=status_code)
116119

117120

118121
@app.post("/upload_tensor_data/{project_id}/{embedding_id}")
119-
def upload_tensor_data(project_id: str, embedding_id: str) -> Tuple[int, str]:
122+
def upload_tensor_data(
123+
project_id: str, embedding_id: str
124+
) -> responses.PlainTextResponse:
120125
session_token = general.get_ctx_token()
121126
controller.upload_embedding_as_file(project_id, embedding_id)
122127
request_util.post_embedding_to_neural_search(project_id, embedding_id)
123128
general.remove_and_refresh_session(session_token)
124-
return 200, ""
129+
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)
125130

126131

127132
@app.put("/config_changed")
128-
def config_changed() -> int:
133+
def config_changed() -> responses.PlainTextResponse:
129134
config_handler.refresh_config()
130-
return 200
135+
return responses.PlainTextResponse(status_code=status.HTTP_200_OK)

controller.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
notification,
1111
organization,
1212
)
13+
from fastapi import status
1314
import pickle
1415
import torch
1516
import traceback
@@ -85,7 +86,7 @@ def get_docbins(
8586
def start_encoding_thread(request: data_type.Request, embedding_type: str) -> int:
8687
doc_ock.post_embedding_creation(request.user_id, request.config_string)
8788
daemon.run(prepare_run_encoding, request, embedding_type)
88-
return 200
89+
return status.HTTP_200_OK
8990

9091

9192
def prepare_run_encoding(request: data_type.Request, embedding_type: str) -> int:
@@ -215,7 +216,7 @@ def run_encoding(
215216
send_project_update(
216217
request.project_id, f"notification_created:{request.user_id}", True
217218
)
218-
return 422
219+
return status.HTTP_422_UNPROCESSABLE_ENTITY
219220
except ValueError:
220221
embedding.update_embedding_state_failed(
221222
request.project_id,
@@ -239,7 +240,7 @@ def run_encoding(
239240
send_project_update(
240241
request.project_id, f"notification_created:{request.user_id}", True
241242
)
242-
return 422
243+
return status.HTTP_422_UNPROCESSABLE_ENTITY
243244

244245
if not embedder:
245246
embedding.update_embedding_state_failed(
@@ -288,7 +289,7 @@ def run_encoding(
288289
f"embedding:{embedding_id}:state:{enums.EmbeddingState.FAILED.value}",
289290
)
290291
doc_ock.post_embedding_failed(request.user_id, request.config_string)
291-
return 422
292+
return status.HTTP_422_UNPROCESSABLE_ENTITY
292293

293294
try:
294295
record_ids, attribute_values_raw = record.get_attribute_data(
@@ -410,7 +411,7 @@ def run_encoding(
410411
)
411412
print(traceback.format_exc(), flush=True)
412413
doc_ock.post_embedding_failed(request.user_id, request.config_string)
413-
return 500
414+
return status.HTTP_500_INTERNAL_SERVER_ERROR
414415

415416
if embedding.get(request.project_id, embedding_id):
416417
for warning_type, idx_list in embedder.get_warnings().items():
@@ -484,7 +485,7 @@ def run_encoding(
484485
doc_ock.post_embedding_finished(request.user_id, request.config_string)
485486
general.commit()
486487
general.remove_and_refresh_session(session_token)
487-
return 200
488+
return status.HTTP_200_OK
488489

489490

490491
def delete_embedding(project_id: str, embedding_id: str) -> int:
@@ -494,12 +495,12 @@ def delete_embedding(project_id: str, embedding_id: str) -> int:
494495
object_name = f"embedding_tensors_{embedding_id}.csv.bz2"
495496

496497
org_id = organization.get_id_by_project_id(project_id)
497-
s3.delete_object(org_id, project_id + "/" + object_name)
498+
s3.delete_object(org_id, f"{project_id}/{object_name}")
498499
request_util.delete_embedding_from_neural_search(embedding_id)
499500
pickle_path = os.path.join("/inference", project_id, f"embedder-{embedding_id}.pkl")
500501
if os.path.exists(pickle_path):
501502
os.remove(pickle_path)
502-
return 200
503+
return status.HTTP_200_OK
503504

504505

505506
@param_throttle(seconds=5)

data/doc_ock.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,13 @@ def _post_event(user_id: str, config_string: str, state: str) -> Any:
4949
"State": state,
5050
"Host": os.getenv("S3_ENDPOINT"),
5151
}
52+
5253
response = requests.post(url, json=data)
53-
if response.status_code == 200:
54-
result, _ = response.json()
55-
return result
56-
else:
54+
55+
if response.status_code != 200:
5756
raise Exception("Could not send data to Doc Ock")
57+
58+
if response.headers.get("content-type") == "application/json":
59+
return response.json()
60+
else:
61+
return response.text

util/config_handler.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@ def __get_config() -> Dict[str, Any]:
2020

2121
def refresh_config():
2222
response = requests.get(REQUEST_URL)
23-
if response.status_code == 200:
24-
global __config
25-
__config = json.loads(json.loads(response.text))
26-
daemon.run(invalidate_after, 3600) # one hour
27-
else:
28-
raise Exception(
23+
if response.status_code != 200:
24+
raise ValueError(
2925
f"Config service cant be reached -- response.code{response.status_code}"
3026
)
27+
global __config
28+
__config = response.json()
29+
daemon.run(invalidate_after, 3600) # one hour
3130

3231

3332
def get_config_value(
3433
key: str, subkey: Optional[str] = None
3534
) -> Union[str, Dict[str, str]]:
3635
config = __get_config()
3736
if key not in config:
38-
raise Exception(f"Key {key} coudn't be found in config")
37+
raise ValueError(f"Key {key} coudn't be found in config")
3938
value = config[key]
4039

4140
if not subkey:
@@ -44,7 +43,7 @@ def get_config_value(
4443
if isinstance(value, dict) and subkey in value:
4544
return value[subkey]
4645
else:
47-
raise Exception(f"Subkey {subkey} coudn't be found in config[{key}]")
46+
raise ValueError(f"Subkey {subkey} coudn't be found in config[{key}]")
4847

4948

5049
def invalidate_after(sec: int) -> None:

0 commit comments

Comments
 (0)