-
-
Notifications
You must be signed in to change notification settings - Fork 1
refactor: migrate to fastapi #14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,274 +1,67 @@ | ||
| import datetime | ||
| """ | ||
| Main application entry point for the OVOS Persona Server. | ||
|
|
||
| This module initializes the FastAPI application, sets up CORS middleware, | ||
| and includes various API routers for chat, embeddings, Ollama, persona status, | ||
| and mock OpenAI Vector Stores. It now centrally manages the unified SQLite database | ||
| initialization using SQLAlchemy. | ||
| """ | ||
| import json | ||
| import os.path | ||
| import random | ||
| import string | ||
| import time | ||
| from typing import Any | ||
| import os | ||
| from typing import Optional | ||
|
|
||
| from flask import Flask, request | ||
| from ovos_bus_client.session import SessionManager | ||
| from fastapi import FastAPI | ||
| from fastapi.middleware.cors import CORSMiddleware | ||
| from ovos_persona import Persona | ||
|
|
||
| import ovos_persona_server.persona | ||
|
|
||
| def get_app(persona_json): | ||
| app = Flask(__name__) | ||
|
|
||
| with open(persona_json) as f: | ||
| persona = json.load(f) | ||
| persona["name"] = persona.get("name") or os.path.basename(persona_json) | ||
|
|
||
| persona = Persona(persona["name"], persona) | ||
|
|
||
| ####### | ||
| @app.route("/status", methods=["GET"]) | ||
| def status(): | ||
| return {"persona": persona.name, | ||
| "solvers": list(persona.solvers.loaded_modules.keys()), | ||
| "models": {s: persona.config.get(s, {}).get("model") | ||
| for s in persona.solvers.loaded_modules.keys()}} | ||
|
|
||
| ############## | ||
| # OpenAI api compat | ||
| @app.route("/chat/completions", methods=["POST"]) | ||
| def chat_completions(): | ||
| data = request.get_json() | ||
| stream = data.get("stream", False) | ||
| messages = data.get("messages") | ||
|
|
||
| completion_id = "".join(random.choices(string.ascii_letters + string.digits, k=28)) | ||
| completion_timestamp = int(time.time()) | ||
|
|
||
| if not stream: | ||
| return { | ||
| "id": f"chatcmpl-{completion_id}", | ||
| "object": "chat.completion", | ||
| "created": completion_timestamp, | ||
| "model": persona.name, | ||
| "choices": [ | ||
| { | ||
| "index": 0, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": persona.chat(messages), | ||
| }, | ||
| "finish_reason": "stop", | ||
| } | ||
| ], | ||
| "usage": { | ||
| "prompt_tokens": None, | ||
| "completion_tokens": None, | ||
| "total_tokens": None, | ||
| }, | ||
| } | ||
|
|
||
| def streaming(): | ||
| for chunk in persona.stream(messages): | ||
| completion_data = { | ||
| "id": f"chatcmpl-{completion_id}", | ||
| "object": "chat.completion.chunk", | ||
| "created": completion_timestamp, | ||
| "model": persona.name, | ||
| "choices": [ | ||
| { | ||
| "index": 0, | ||
| "delta": { | ||
| "content": chunk, | ||
| }, | ||
| "finish_reason": None, | ||
| } | ||
| ], | ||
| } | ||
|
|
||
| content = json.dumps(completion_data, separators=(",", ":")) | ||
| yield f"data: {content}\n\n" | ||
| time.sleep(0.1) | ||
|
|
||
| end_completion_data: dict[str, Any] = { | ||
| "id": f"chatcmpl-{completion_id}", | ||
| "object": "chat.completion.chunk", | ||
| "created": completion_timestamp, | ||
| "model": persona.name, | ||
| "choices": [ | ||
| { | ||
| "index": 0, | ||
| "delta": {}, | ||
| "finish_reason": "stop", | ||
| } | ||
| ], | ||
| } | ||
| content = json.dumps(end_completion_data, separators=(",", ":")) | ||
| yield f"data: {content}\n\n" | ||
| def create_persona_app(persona_path: str) -> FastAPI: | ||
| """ | ||
| Creates and configures the FastAPI application for the Persona Server. | ||
|
|
||
| return app.response_class(streaming(), mimetype="text/event-stream") | ||
| Args: | ||
| persona_path (Optional[str]): Optional path to a persona JSON file. | ||
| If provided, it overrides the default | ||
| persona path from settings or environment. | ||
|
|
||
| ############ | ||
| # Ollama api compat | ||
| @app.route("/api/chat", methods=["POST"]) | ||
| def chat(): | ||
| model = request.json.get("model") | ||
| messages = request.json.get("messages") | ||
| tools = request.json.get("tools") | ||
| stream = request.json.get("stream") | ||
| Returns: | ||
| FastAPI: The configured FastAPI application instance. | ||
| """ | ||
|
|
||
| # Format timestamp to the desired format | ||
| completion_timestamp = (datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S') | ||
| + f'.{int(time.time() * 1_000_000) % 1_000_000:06d}Z') | ||
|
|
||
| sess = SessionManager().get() | ||
|
|
||
| if not stream: | ||
| ans = persona.chat(messages, lang=sess.lang, units=sess.system_unit) | ||
| data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": ans, | ||
| }, | ||
| "done": True | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 5043500667, | ||
| # "load_duration": 5025959, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 325953000, | ||
| # "eval_count": 290, | ||
| # "eval_duration": 4709213000 | ||
| } | ||
| return data | ||
|
|
||
| def streaming(): | ||
| for ans in persona.stream(messages, lang=sess.lang, units=sess.system_unit): | ||
| data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": ans | ||
| }, | ||
| "done": False, | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 10706818083, | ||
| # "load_duration": 6338219291, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 130079000, | ||
| # "eval_count": 259, | ||
| # "eval_duration": 4232710000 | ||
| } | ||
| content = json.dumps(data) | ||
| yield content + "\n" | ||
|
|
||
| end_completion_data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": "" | ||
| }, | ||
| "done": True, | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 10706818083, | ||
| # "load_duration": 6338219291, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 130079000, | ||
| # "eval_count": 259, | ||
| # "eval_duration": 4232710000 | ||
| } | ||
| content = json.dumps(end_completion_data) | ||
| yield content + "\n" | ||
|
|
||
| return app.response_class(streaming(), mimetype="application/json") | ||
|
|
||
| @app.route("/api/generate", methods=["POST"]) | ||
| def generate(): | ||
| model = request.json.get("model") | ||
| prompt = request.json.get("prompt") | ||
| suffix = request.json.get("suffix") | ||
| system = request.json.get("system") | ||
| template = request.json.get("template") | ||
| stream = request.json.get("stream") | ||
|
|
||
| sess = SessionManager().get() | ||
| with open(persona_path) as f: | ||
| persona = json.load(f) | ||
| persona["name"] = persona.get("name") or os.path.basename(persona_path) | ||
|
|
||
| messages = [{ | ||
| "role": "user", | ||
| "content": prompt | ||
| }] | ||
| if system: | ||
| messages.insert(0, {"role": "system", "content": system}) | ||
| # TODO - move to dependency injection | ||
| ovos_persona_server.persona.default_persona = persona = Persona(persona["name"], persona) | ||
|
|
||
| # Format timestamp to the desired format | ||
| completion_timestamp = (datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S') | ||
| + f'.{int(time.time() * 1_000_000) % 1_000_000:06d}Z') | ||
| from ovos_persona_server.version import VERSION_MAJOR, VERSION_ALPHA, VERSION_BUILD, VERSION_MINOR | ||
|
|
||
| sess = SessionManager().get() | ||
| version_str = f"{VERSION_MAJOR}.{VERSION_MINOR}.{VERSION_BUILD}" | ||
| if VERSION_ALPHA: | ||
| version_str += f"a{VERSION_ALPHA}" | ||
|
|
||
| if not stream: | ||
| ans = persona.chat(messages, lang=sess.lang, units=sess.system_unit) | ||
| data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": ans, | ||
| }, | ||
| "done": True | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 5043500667, | ||
| # "load_duration": 5025959, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 325953000, | ||
| # "eval_count": 290, | ||
| # "eval_duration": 4709213000 | ||
| } | ||
| return data | ||
| app = FastAPI(title="OVOS Persona Server", | ||
| description="OpenAI/Ollama compatible API for OVOS Personas and Solvers", | ||
| version=version_str) | ||
|
|
||
| def streaming(): | ||
| for ans in persona.stream(messages, lang=sess.lang, units=sess.system_unit): | ||
| data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": ans | ||
| }, | ||
| "done": False, | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 10706818083, | ||
| # "load_duration": 6338219291, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 130079000, | ||
| # "eval_count": 259, | ||
| # "eval_duration": 4232710000 | ||
| } | ||
| content = json.dumps(data) | ||
| yield content + "\n" | ||
| app.add_middleware( | ||
| CORSMiddleware, | ||
| allow_origins=["*"], # Allows all origins | ||
| allow_credentials=True, | ||
| allow_methods=["*"], # Allows all methods (GET, POST, etc.) | ||
| allow_headers=["*"], # Allows all headers | ||
| ) | ||
|
Comment on lines
+50
to
+56
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
CORS: wildcard origins + credentials is unsafe/invalid
Minimal fix:
- allow_origins=["*"], # Allows all origins
- allow_credentials=True,
+ allow_origins=os.getenv("OVOS_ALLOWED_ORIGINS", "*").split(","),
+ allow_credentials=False,
I can wire this to settings/env and add tests.
🤖 Prompt for AI Agents |
||
|
|
||
| end_completion_data = { | ||
| "model": persona.name, | ||
| "created_at": completion_timestamp, | ||
| "message": { | ||
| "role": "assistant", | ||
| "content": "" | ||
| }, | ||
| "done": True, | ||
| # "context": [1, 2, 3], | ||
| # "total_duration": 10706818083, | ||
| # "load_duration": 6338219291, | ||
| # "prompt_eval_count": 26, | ||
| # "prompt_eval_duration": 130079000, | ||
| # "eval_count": 259, | ||
| # "eval_duration": 4232710000 | ||
| } | ||
| content = json.dumps(end_completion_data) | ||
| yield content + "\n" | ||
| # Include routers for different API functionalities | ||
| # imported here only after the Persona object is loaded | ||
| from ovos_persona_server.chat import chat_router | ||
| from ovos_persona_server.ollama import ollama_router | ||
|
|
||
| return app.response_class(streaming(), mimetype="text/event-stream") | ||
| app.include_router(chat_router) | ||
| app.include_router(ollama_router) | ||
|
|
||
| @app.route("/api/tags", methods=["GET"]) | ||
| def tags(): | ||
| return {"models": [ | ||
| {"name": persona.name, "model": str(persona.solvers.sort_order[0])} | ||
| ]} | ||
|
|
||
| return app | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| @@ -1,19 +1,32 @@ | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Command-line entry point for running the OVOS Persona Server. | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| This module parses command-line arguments to configure and start | ||||||||||||||||||||||||||||||||||||||
| the FastAPI application using Uvicorn. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| import argparse | ||||||||||||||||||||||||||||||||||||||
| import os.path | ||||||||||||||||||||||||||||||||||||||
| from typing import Any | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| import uvicorn | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from ovos_persona_server import create_persona_app | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| from ovos_persona_server import get_app | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| def main(): | ||||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser() | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--persona", help="path to persona .json file", required=True) | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--host", help="host", default="0.0.0.0") | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--port", help="port to run server", default=8337) | ||||||||||||||||||||||||||||||||||||||
| args = parser.parse_args() | ||||||||||||||||||||||||||||||||||||||
| def main() -> None: | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| Main function to parse arguments and start the Persona Server. | ||||||||||||||||||||||||||||||||||||||
| """ | ||||||||||||||||||||||||||||||||||||||
| parser = argparse.ArgumentParser(description="OVOS Persona Server") | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--persona", help="Path to persona .json file", default=None, type=str) | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--host", help="Host address to bind to", default="0.0.0.0", type=str) | ||||||||||||||||||||||||||||||||||||||
| parser.add_argument("--port", help="Port to run server on", default=8337, type=int) | ||||||||||||||||||||||||||||||||||||||
| args: Any = parser.parse_args() # Using Any for args as argparse.Namespace is dynamic | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| app = get_app(os.path.expanduser(args.persona)) | ||||||||||||||||||||||||||||||||||||||
| app = create_persona_app(args.persona) | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| app.run(host=args.host, port=args.port, debug=False) | ||||||||||||||||||||||||||||||||||||||
| uvicorn.run(app, port=args.port, host=args.host, log_level="debug") | ||||||||||||||||||||||||||||||||||||||
|
Comment on lines
+22
to
+29
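There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Make --persona required to avoid runtime crash. With `default=None`, omitting the flag passes `None` to `create_persona_app`, whose `open(persona_path)` call then raises `TypeError`.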
Apply:
- parser.add_argument("--persona", help="Path to persona .json file", default=None, type=str)
+ parser.add_argument("--persona", help="Path to persona .json file", required=True, type=str)
Optional:
📝 Committable suggestion
Suggested change
🧰 Tools
🪛 Ruff (0.14.0)
23-23: Possible binding to all interfaces (S104)
🤖 Prompt for AI Agents |
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
| if __name__ == "__main__": | ||||||||||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Harden persona loading and align signature/docs
Apply:
Also applies to: 33-36
🤖 Prompt for AI Agents