
Commit ee655bb

Feat/mcp interface (#260)
* add fastapi app
* add test for http
* add mcp app
* add deps
* add tests for mcp
* `api` -> `server`
* properly load dataset
* remove unnecessary data models
* run code formatter
* configure mcp server running
* configure http server running
* fix typing
* fix typing
1 parent deae237 commit ee655bb

File tree

11 files changed (+830, -6 lines)


autointent/_dump_tools/unit_dumpers.py

Lines changed: 1 addition & 1 deletion

@@ -12,7 +12,7 @@
 from peft import PeftModel
 from pydantic import BaseModel
 from sklearn.base import BaseEstimator
-from transformers import (  # type: ignore[attr-defined]
+from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
     PreTrainedModel,

autointent/context/data_handler/_stratification.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 from numpy import typing as npt
 from sklearn.model_selection import train_test_split
 from skmultilearn.model_selection import IterativeStratification
-from transformers import set_seed  # type: ignore[attr-defined]
+from transformers import set_seed

 from autointent import Dataset
 from autointent.custom_types import LabelType

autointent/modules/scoring/_bert.py

Lines changed: 1 addition & 1 deletion

@@ -9,7 +9,7 @@
 import torch
 from datasets import Dataset, DatasetDict
 from sklearn.model_selection import train_test_split
-from transformers import (  # type: ignore[attr-defined]
+from transformers import (
     AutoModelForSequenceClassification,
     AutoTokenizer,
     DataCollatorWithPadding,

autointent/server/__init__.py

Lines changed: 1 addition & 0 deletions

@@ -0,0 +1 @@
+

autointent/server/http.py

Lines changed: 114 additions & 0 deletions

@@ -0,0 +1,114 @@
+"""FastAPI application for AutoIntent pipeline inference."""
+
+import logging
+from collections.abc import AsyncGenerator
+from contextlib import asynccontextmanager
+from functools import lru_cache
+from pathlib import Path
+
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel, Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+
+from autointent import Pipeline
+from autointent.custom_types import ListOfLabelsWithOOS
+
+
+class Settings(BaseSettings):
+    """Application settings loaded from environment variables."""
+
+    model_config = SettingsConfigDict(env_file=".env", env_prefix="AUTOINTENT_")
+    path: str = Field(..., description="Path to the optimized pipeline assets")
+    host: str = "127.0.0.1"
+    port: int = 8013
+
+
+class PredictRequest(BaseModel):
+    """Request model for the predict endpoint."""
+
+    utterances: list[str] = Field(..., description="List of text utterances to classify")
+
+
+class PredictResponse(BaseModel):
+    """Response model for the predict endpoint."""
+
+    predictions: ListOfLabelsWithOOS = Field(..., description="List of predicted class labels")
+
+
+settings = Settings()
+logger = logging.getLogger(__name__)
+
+
+@lru_cache(maxsize=1)
+def load_pipeline() -> Pipeline:
+    """Load the optimized pipeline from disk."""
+    pipeline_path = Path(settings.path)
+    if not pipeline_path.exists():
+        msg = f"Pipeline path does not exist: {pipeline_path}"
+        logger.error(msg)
+        raise HTTPException(status_code=404, detail=msg)
+
+    try:
+        msg = f"Loading pipeline from: {pipeline_path}"
+        logger.info(msg)
+        pipeline = Pipeline.load(pipeline_path)
+        logger.info("Pipeline loaded successfully")
+
+    except Exception:
+        logger.exception("Failed to load pipeline")
+        raise
+    else:
+        return pipeline
+
+
+@asynccontextmanager
+async def lifespan(_: FastAPI) -> AsyncGenerator[None, None]:
+    """Load pipe."""
+    load_pipeline()
+    yield
+
+
+app = FastAPI(
+    title="AutoIntent Pipeline API",
+    description="API for serving AutoIntent predictions",
+    version="0.0.1",
+    lifespan=lifespan,
+)
+
+
+@app.get("/health")
+async def health_check() -> dict[str, str]:
+    """Health check endpoint."""
+    return {"status": "healthy"}
+
+
+@app.post("/predict")
+async def predict(request: PredictRequest) -> PredictResponse:
+    """Predict class labels for the given utterances.
+
+    Args:
+        request: Request containing list of utterances to classify
+
+    Returns:
+        Response containing predicted class labels
+    """
+    current_pipeline = load_pipeline()
+
+    if not request.utterances:
+        return PredictResponse(predictions=[])
+
+    predictions = current_pipeline.predict(request.utterances)
+
+    return PredictResponse(predictions=predictions)
+
+
+def main() -> None:
+    """Main entry point for the HTTP server."""
+    import uvicorn
+
+    uvicorn.run(
+        "autointent.server.http:app",
+        host=settings.host,
+        port=settings.port,
+        reload=False,
+    )
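
For context, a minimal launch sketch (an assumption, not part of the commit): with env_prefix="AUTOINTENT_", the required Settings.path field should map to an AUTOINTENT_PATH environment variable via standard pydantic-settings behavior. The package's console entry point is not shown in this diff, so the sketch calls main() directly, and the pipeline path is a hypothetical placeholder.

# Minimal launch sketch (assumption, not from the commit): set the environment
# variables before importing the module, because `settings = Settings()` runs
# at import time, then start the server through main().
import os

os.environ["AUTOINTENT_PATH"] = "/path/to/optimized_pipeline"  # hypothetical pipeline location
# host/port default to 127.0.0.1:8013 as defined in Settings; override if needed:
os.environ["AUTOINTENT_PORT"] = "8013"

from autointent.server.http import main

if __name__ == "__main__":
    main()  # runs uvicorn with "autointent.server.http:app"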

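Once the server is running, the two endpoints can be exercised from any HTTP client. The sketch below assumes the default 127.0.0.1:8013 from Settings and uses httpx (not a dependency added by this commit); the utterances are purely illustrative.

# Client-side sketch against the endpoints added in this commit,
# assuming the server listens on the default 127.0.0.1:8013.
import httpx

BASE_URL = "http://127.0.0.1:8013"

# GET /health returns {"status": "healthy"} once the app has started.
print(httpx.get(f"{BASE_URL}/health").json())

# POST /predict takes a JSON body matching PredictRequest and returns
# a PredictResponse with a "predictions" field.
response = httpx.post(
    f"{BASE_URL}/predict",
    json={"utterances": ["book a flight to Paris", "what is my balance?"]},
)
response.raise_for_status()
print(response.json()["predictions"])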