|
3 | 3 | import secrets |
4 | 4 | import logging |
5 | 5 | import os |
6 | | - |
7 | 6 | from pydantic import BaseModel |
8 | 7 | from transformers import pipeline |
9 | 8 | from fastapi import FastAPI, HTTPException, status, Depends |
|
12 | 11 | from dotenv import load_dotenv |
13 | 12 | from cachier import cachier |
14 | 13 |
|
| 14 | +# ------------------ SETUP ------------------ |
| 15 | + |
| 16 | +# Load environment variables |
15 | 17 | load_dotenv() |
16 | 18 |
|
| 19 | +# Create FastAPI instance |
17 | 20 | app = FastAPI(docs_url=None, redoc_url=None, openapi_url=None) |
18 | 21 |
|
19 | | -# auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH |
| 22 | +# Setup logging |
| 23 | +logging.basicConfig(level=logging.INFO) |
| 24 | +logger = logging.getLogger(__name__) |
| 25 | + |
| 26 | +# ------------------ AUTHENTICATION ------------------ |
| 27 | + |
| 28 | +# Auth with a bearer api key, whose hash is stored in the environment variable API_KEY_HASH |
20 | 29 | oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token") |
21 | 30 | API_KEY_HASH = os.getenv("API_KEY_HASH") |
22 | 31 | if not API_KEY_HASH and os.path.exists("/run/secrets/api_key_hash"): |
23 | 32 | with open("/run/secrets/api_key_hash", "r") as f: |
24 | 33 | API_KEY_HASH = f.read().strip() |
| 34 | + logger.info("API key hash loaded from secret") |
| 35 | +else: |
| 36 | + logger.info("API key hash loaded from environment variable") |
25 | 37 |
|
26 | 38 | assert API_KEY_HASH, "API_KEY_HASH must be set" |
27 | 39 |
|
| 40 | + |
| 41 | +# Function to verify API key |
| 42 | +def verify_api_key(token: str): |
| 43 | + token_hash: str = hashlib.sha256(token.encode()).hexdigest() |
| 44 | + return secrets.compare_digest(token_hash, API_KEY_HASH) |
| 45 | + |
| 46 | + |
| 47 | +# Dependency to authenticate user |
| 48 | +async def authenticate_user(token: str = Depends(oauth2_scheme)): |
| 49 | + if not verify_api_key(token): |
| 50 | + raise HTTPException( |
| 51 | + status_code=status.HTTP_401_UNAUTHORIZED, |
| 52 | + detail="Invalid API Key", |
| 53 | + headers={"WWW-Authenticate": "Bearer"}, |
| 54 | + ) |
| 55 | + return token |
| 56 | + |
| 57 | + |
| 58 | +# ------------------ CLASSIFICATION ------------------ |
| 59 | + |
| 60 | +# Setup classifier |
28 | 61 | classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli") |
29 | 62 |
|
| 63 | +# Default labels |
30 | 64 | DEFAULT_LABELS: list[str] = [ |
31 | 65 | "programming", |
32 | 66 | "politics", |
|
36 | 70 | "video games", |
37 | 71 | ] |
38 | 72 |
|
39 | | -pool = ThreadPoolExecutor(max_workers=1) |
40 | | - |
41 | | -logging.basicConfig(level=logging.INFO) |
42 | | - |
43 | 73 |
|
| 74 | +# Classification model |
44 | 75 | class Classification(BaseModel): |
45 | 76 | sequence: str = "The text to classify" |
46 | 77 | labels: list[str] = DEFAULT_LABELS |
47 | 78 | scores: list[float] = [0.0] * len(DEFAULT_LABELS) |
48 | 79 |
|
49 | 80 |
|
| 81 | +# Function to classify message |
50 | 82 | @cachier(cache_dir="./cache") |
51 | 83 | def classify_sync(message: str, labels: list[str]) -> dict: |
52 | 84 | result = classifier(message, candidate_labels=labels) |
53 | 85 | return result |
54 | 86 |
|
55 | 87 |
|
56 | | -# setup auth |
57 | | -def verify_api_key(token: str): |
58 | | - token_hash: str = hashlib.sha256(token.encode()).hexdigest() |
59 | | - return secrets.compare_digest(token_hash, API_KEY_HASH) |
60 | | - |
61 | | - |
62 | | -async def authenticate_user(token: str = Depends(oauth2_scheme)): |
63 | | - if not verify_api_key(token): |
64 | | - raise HTTPException( |
65 | | - status_code=status.HTTP_401_UNAUTHORIZED, |
66 | | - detail="Invalid API Key", |
67 | | - headers={"WWW-Authenticate": "Bearer"}, |
68 | | - ) |
69 | | - return token |
70 | | - |
| 88 | +# ------------------ ROUTES ------------------ |
71 | 89 |
|
72 | | -classification_lock = asyncio.Lock() # Ensure only one classification at a time |
| 90 | +# Lock to ensure only one classification at a time |
| 91 | +classification_lock = asyncio.Lock() |
73 | 92 |
|
74 | 93 |
|
| 94 | +# Route to classify message |
75 | 95 | @app.get("/v1/classify") |
76 | 96 | async def classify( |
77 | 97 | message: str, labels: list[str] = None, token: str = Depends(authenticate_user) |
78 | 98 | ) -> Classification: |
79 | | - """ |
80 | | - Classify the message into one of the labels |
81 | | - :param message: The message to classify |
82 | | - :type message: str |
83 | | - :param labels: The labels to classify the message into |
84 | | - :type labels: list[str] |
85 | | - :return: The classification result |
86 | | - :rtype: Classification |
87 | | - """ |
88 | 99 | labels = labels or DEFAULT_LABELS |
89 | | - async with classification_lock: |
| 100 | + async with classification_lock: # Ensure only one classification at a time |
90 | 101 | loop = asyncio.get_event_loop() |
91 | 102 | result = await loop.run_in_executor(None, classify_sync, message, labels) |
92 | 103 | result = Classification(**result) |
93 | 104 | return result |
94 | 105 |
|
95 | 106 |
|
| 107 | +# Health check route |
96 | 108 | @app.get("/v1/health") |
97 | 109 | async def health() -> dict: |
98 | 110 | return {"status": "ok"} |
0 commit comments