-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathmain.py
More file actions
32 lines (24 loc) · 1.07 KB
/
main.py
File metadata and controls
32 lines (24 loc) · 1.07 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
import pickle
import re

from fastapi import FastAPI, HTTPException

from models.french_article import FrenchArticle
import models.ml.classifier as clf
from config import Config
# Application object; the route and startup-hook decorators below register
# themselves on it.
app = FastAPI(
    title="News Classification ML API",
    description="API for french news dataset ml model",
    version="1.0",
)
@app.on_event('startup')
def load_model():
    """Load the fitted TF-IDF vectoriser and classifier at application startup.

    The unpickled objects are stored as module attributes ``clf.tfidf`` and
    ``clf.model``; the request handlers read them from there.

    Security: ``pickle.load`` can execute arbitrary code embedded in the file.
    The paths come from ``Config`` (not from request data), and the files
    they point to must come from a trusted source.
    """
    # NOTE(review): ``on_event`` is deprecated in recent FastAPI releases in
    # favour of ``lifespan`` handlers. Consider migrating.
    # Context managers close the handles deterministically instead of
    # leaving them open until garbage collection.
    with open(Config.tfidf_pipeline, 'rb') as tfidf_file:
        clf.tfidf = pickle.load(tfidf_file)
    with open(Config.model, 'rb') as model_file:
        clf.model = pickle.load(model_file)
# Each run of characters outside this set is replaced by a single space.
# NOTE(review): this character class looks wrong for French text. Check it
# against the preprocessing used when the TF-IDF vectoriser was fitted before
# changing it, because a mismatch would cause train/serve skew:
#   * ``1-9`` omits the digit ``0``, so "2020" becomes "2 2 ";
#   * accented Latin letters (é, è, à, ç, ...) are excluded, so "été"
#     becomes " t ";
#   * ``а-яА-Я`` is the Cyrillic alphabet, not French.
_TEXT_FILTER_RE = re.compile('[^a-zA-Zа-яА-Я1-9]+')


@app.post('/predict', tags=["predictions"])
async def get_prediction(french_article: FrenchArticle):
    """Classify a French news article and return per-category probabilities.

    Args:
        french_article: Request body. Only ``data[0]`` is classified
            (presumably ``data`` is a list of article texts; confirm against
            ``FrenchArticle``).

    Returns:
        A dict with two keys:
        ``"prediction"``: category name of the predicted class.
        ``"log_proba"``: mapping category name -> probability. Despite the
        key name, these are plain probabilities from ``predict_proba``, not
        log-probabilities. The key is kept for client compatibility.

    Raises:
        HTTPException: 422 when ``french_article.data`` is empty (previously
            an unhandled IndexError, i.e. a 500).
    """
    # NOTE(review): this handler is ``async`` but performs blocking CPU work
    # (TF-IDF transform + inference), which stalls the event loop while it
    # runs. A plain ``def`` would let FastAPI run it in its threadpool.
    if not french_article.data:
        raise HTTPException(status_code=422,
                            detail="data must contain at least one text")

    # A single substitution is enough: spaces are outside the allowed class,
    # so every run of them is already collapsed to one space here.
    text = _TEXT_FILTER_RE.sub(' ', french_article.data[0])

    features = clf.tfidf.transform([text])
    predicted_label = clf.model.predict(features).tolist()[0]
    probabilities = clf.model.predict_proba(features).tolist()[0]

    # scikit-learn orders predict_proba columns like ``classes_``. Use that
    # ordering when available rather than assuming the labels are exactly
    # 0..n-1. Fall back to the column index, which was the old behaviour.
    classes = getattr(clf.model, "classes_", None)
    labels = range(len(probabilities)) if classes is None else classes

    category_probabilities = {
        Config.id_to_category[label]: probability
        for label, probability in zip(labels, probabilities)
    }
    return {"prediction": Config.id_to_category[predicted_label],
            "log_proba": category_probabilities}