Skip to content

Commit 4a397e0

Browse files
authored
Initial inference documentation pass (#3330)
In the recent meeting it was raised that the codebase is undocumented and, as a result, hard to understand. This adds some initial documentation. For now it is quite basic, but it should make some parts of the code easier to understand "at a glance", reducing the friction somewhat.
1 parent 960dd2e commit 4a397e0

File tree

18 files changed

+113
-7
lines changed

18 files changed

+113
-7
lines changed

inference/safety/main.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,8 @@
1-
# A FastAPI server to run the safety pipeline
1+
"""
2+
A simple FastAPI server which serves a `blade2blade` safety model.
3+
4+
See https://github.com/LAION-AI/blade2blade for context.
5+
"""
26

37
import asyncio
48
from concurrent.futures import ThreadPoolExecutor
@@ -40,6 +44,7 @@ async def load_pipeline():
4044

4145

4246
async def async_predict(pipeline: Blade2Blade, inputs: str):
47+
"""Run predictions in a separate thread for a small server parallelism benefit."""
4348
return await asyncio.get_event_loop().run_in_executor(executor, pipeline.predict, inputs)
4449

4550

inference/safety/settings.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33

44
class Settings(pydantic.BaseSettings):
5+
# HuggingFace model ID for the model to load in blade2blade
56
safety_model_name: str = "shahules786/blade2blade-t5-base"
67

78

inference/server/export.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Script to facilitate exporting chat data from the server database."""
2+
13
import argparse
24
import asyncio
35
import contextlib

inference/server/main.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ def terminate_server(signum, frame):
5656

5757
@app.on_event("startup")
5858
async def alembic_upgrade():
59+
"""Upgrades database schema based on Alembic migration scripts."""
5960
signal.signal(signal.SIGINT, terminate_server)
6061
if not settings.update_alembic:
6162
logger.warning("Skipping alembic upgrade on startup (update_alembic is False)")
@@ -113,7 +114,7 @@ async def maybe_add_debug_api_keys():
113114
app.include_router(workers.router)
114115
app.include_router(configs.router)
115116

116-
# mount plugins
117+
# mount builtin plugins to be hosted on this server
117118
for app_prefix, sub_app in plugins.plugin_apps.items():
118119
app.mount(path=settings.plugins_path_prefix + app_prefix, app=sub_app)
119120

inference/server/oasst_inference_server/admin.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Logic related to admin actions."""
2+
13
import fastapi
24
from loguru import logger
35
from oasst_inference_server import database, models

inference/server/oasst_inference_server/auth.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Logic related to authorization actions."""
2+
13
import hashlib
24
import json
35
from datetime import datetime, timedelta

inference/server/oasst_inference_server/chat_repository.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313

1414
class ChatRepository(pydantic.BaseModel):
15+
"""Wrapper around a database session providing functionality relating to chats."""
16+
1517
session: database.AsyncSession
1618

1719
class Config:
@@ -38,6 +40,10 @@ async def get_prompter_message_by_id(self, message_id: str) -> models.DbMessage:
3840
async def start_work(
3941
self, *, message_id: str, worker_id: str, worker_config: inference.WorkerConfig
4042
) -> models.DbMessage:
43+
"""
44+
Update an assistant message in the database to be allocated to a specific worker.
45+
The message must be in `pending` state. An exception is raised if the message has timed out or was cancelled.
46+
"""
4147
logger.debug(f"Starting work on message {message_id}")
4248
message = await self.get_assistant_message_by_id(message_id)
4349

@@ -65,6 +71,10 @@ async def start_work(
6571
return message
6672

6773
async def reset_work(self, message_id: str) -> models.DbMessage:
74+
"""
75+
Update an assistant message in the database which has already been allocated to a worker to remove the
76+
allocation and reset the message state to `pending`.
77+
"""
6878
logger.warning(f"Resetting work on message {message_id}")
6979
message = await self.get_assistant_message_by_id(message_id)
7080
message.state = inference.MessageState.pending
@@ -78,6 +88,7 @@ async def reset_work(self, message_id: str) -> models.DbMessage:
7888
return message
7989

8090
async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
91+
"""Update an assistant message in the database to mark it as having been aborted by the allocated worker."""
8192
logger.warning(f"Aborting work on message {message_id}")
8293
message = await self.get_assistant_message_by_id(message_id)
8394
message.state = inference.MessageState.aborted_by_worker
@@ -88,7 +99,13 @@ async def abort_work(self, message_id: str, reason: str) -> models.DbMessage:
8899
await self.session.refresh(message)
89100
return message
90101

91-
async def complete_work(self, message_id: str, content: str, used_plugin: inference.PluginUsed) -> models.DbMessage:
102+
async def complete_work(
103+
self, message_id: str, content: str, used_plugin: inference.PluginUsed | None
104+
) -> models.DbMessage:
105+
"""
106+
Update an assistant message in the database to mark it as having been completed with the given content, also
107+
updating the used plugin if one is specified.
108+
"""
92109
logger.debug(f"Completing work on message {message_id}")
93110
message = await self.get_assistant_message_by_id(message_id)
94111
message.state = inference.MessageState.complete

inference/server/oasst_inference_server/chat_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44

55
def get_model_config(model_config_name: str) -> model_configs.ModelConfig:
6+
"""Get a `ModelConfig` by its name. See `oasst_shared.model_configs`."""
67
if settings.allowed_model_config_names != "*":
78
if model_config_name not in settings.allowed_model_config_names_list:
89
raise ValueError(f"Model {model_config_name} not in allowed models: {settings.allowed_model_config_names}")

inference/server/oasst_inference_server/compliance.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
"""Logic related to worker compliance checks, which seek to ensure workers do not produce malicious responses."""
2+
13
import datetime
24
from typing import cast
35

@@ -14,6 +16,10 @@
1416
async def find_compliance_work_request_message(
1517
session: database.AsyncSession, worker_config: inference.WorkerConfig, worker_id: str
1618
) -> models.DbMessage | None:
19+
"""
20+
Find a suitable assistant message to carry out a worker compliance check for the given worker. Such a message must
21+
have been generated by a different worker, but one with the same compatibility hash as the given worker.
22+
"""
1723
compat_hash = worker_config.compat_hash
1824
query = (
1925
sqlmodel.select(models.DbMessage)
@@ -30,6 +36,10 @@ async def find_compliance_work_request_message(
3036

3137

3238
async def should_do_compliance_check(session: database.AsyncSession, worker_id: str) -> bool:
39+
"""
40+
Check whether we should carry out a compliance check for the given worker, based on time since last check.
41+
Trusted workers are excluded.
42+
"""
3343
worker = await worker_utils.get_worker(worker_id, session)
3444
if worker.trusted:
3545
return False
@@ -43,6 +53,13 @@ async def should_do_compliance_check(session: database.AsyncSession, worker_id:
4353

4454

4555
async def run_compliance_check(websocket: fastapi.WebSocket, worker_id: str, worker_config: inference.WorkerConfig):
56+
"""
57+
Run a compliance check for the given worker:
58+
- Find a suitable compliance check assistant message
59+
- Task the worker with generating a response with the same context
60+
- Compare the respons against the existing completed message
61+
- Update the database with the outcome
62+
"""
4663
async with deps.manual_create_session() as session:
4764
try:
4865
worker = await worker_utils.get_worker(worker_id, session)

inference/server/oasst_inference_server/database.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ async def get_async_session(autoflush=True):
7575

7676

7777
def alembic_upgrade(connection):
78+
"""Upgrades database schema based on Alembic migration scripts."""
7879
alembic_ini_path = Path(__file__).parent.parent / "alembic.ini"
7980
alembic_cfg = alembic.config.Config(str(alembic_ini_path))
8081
alembic_cfg.set_main_option("sqlalchemy.url", settings.database_uri)

0 commit comments

Comments
 (0)