3 changes: 3 additions & 0 deletions .env
@@ -33,3 +33,6 @@ MLFLOW_ARTIFACT_DESTINATION=./mlruns

# this path is relative to where jupyter is started
MODEL_SECRETS_PATH=./config/secrets.toml

# Used by the mock vllm server to authenticate requests
VLLM_API_KEY=changeme
10 changes: 10 additions & 0 deletions README.md
@@ -2,6 +2,16 @@

Develop new evaluators / annotators.

## ⚠️ Content warning

The sample datasets provided in the [`flightpaths/data`](https://github.com/mlcommons/modelplane/tree/main/flightpaths/data)
directory are truncated versions of the datasets provided [here](https://github.com/mlcommons/ailuminate).
These data come with the following warning:

>This dataset was created to elicit hazardous responses. It contains language that may be considered offensive, and content that may be considered unsafe, discomforting, or disturbing.
>Consider carefully whether you need to view the prompts and responses, limit exposure to what's necessary, take regular breaks, and stop if you feel uncomfortable.
>For more information on the risks, see [this literature review](https://www.zevohealth.com/wp-content/uploads/2024/07/lit_review_IN-1.pdf) on vicarious trauma.

## Get Started

You must have docker installed on your system. The
2 changes: 2 additions & 0 deletions docker-compose.yaml
@@ -64,6 +64,7 @@ services:
USE_PRIVATE_MODELBENCH: ${USE_PRIVATE_MODELBENCH}
JUPYTER_TOKEN: ${JUPYTER_TOKEN}
GIT_PYTHON_REFRESH: ${GIT_PYTHON_REFRESH}
VLLM_API_KEY: ${VLLM_API_KEY}
# Below env needed for dvc (via git) support (backed by GCP)
# SSH_AUTH_SOCK: /ssh-agent
# GOOGLE_APPLICATION_CREDENTIALS: /creds/gcp-key.json
@@ -86,6 +87,7 @@
VLLM_MODEL: mlc/not-real-model
VLLM_HOST: 0.0.0.0
VLLM_PORT: 8001
VLLM_API_KEY: ${VLLM_API_KEY}
ports:
- "8001:8001"
healthcheck:
14 changes: 14 additions & 0 deletions flightpaths/config/secrets.toml
@@ -0,0 +1,14 @@
# Edit this file to add your secrets.

# This is an example of how to define a secret.
# It defines a key named "api_key" within the "vllm" scope
# and sets it to the value "changeme".
[vllm]
api_key = "changeme"

# Here are some commonly needed keys you can uncomment and use.
[together]
# api_key = "fake key"

[perspective_api]
# api_key = "<your key here>"
28 changes: 21 additions & 7 deletions flightpaths/vLLM Annotator.ipynb
@@ -22,6 +22,8 @@
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"from modelplane.runways import responder, annotator"
]
},
@@ -62,7 +64,8 @@
"\n",
"vllm_host = \"http://vllm:8001/v1\"\n",
"vllm_model = \"mlc/not-real-model\"\n",
"vllm_annotator_uid = \"vllm_dummy\""
"vllm_annotator_uid = \"vllm_dummy\"\n",
"vllm_api_key = os.getenv(\"VLLM_API_KEY\", \"changeme\")"
]
},
{
@@ -125,27 +128,37 @@
"source": [
"from typing import Dict\n",
"\n",
"from openai import OpenAI\n",
"from pydantic import BaseModel\n",
"\n",
"from modelgauge.annotator import CompletionAnnotator\n",
"from modelgauge.safety_model_response import SafetyModelResponse\n",
"from modelgauge.single_turn_prompt_response import TestItem\n",
"from modelgauge.secret_values import RequiredSecret, SecretDescription\n",
"from modelgauge.sut import SUTResponse\n",
"from modelgauge.tests.safe_v1 import BaseSafeTestVersion1\n",
"from openai import OpenAI\n",
"from pydantic import BaseModel\n",
"\n",
"from modelplane.utils import vllm\n",
"\n",
"class VLLMSecret(RequiredSecret):\n",
" @classmethod\n",
" def description(cls) -> SecretDescription:\n",
" return SecretDescription(\n",
" scope=\"vllm\",\n",
" key=\"api_key\",\n",
" instructions=\"The key value should be set to match what your VLLM server expects.\",\n",
" )\n",
"\n",
"\n",
"class AnnotationRequest(BaseModel):\n",
" prompts: Dict[str, str]\n",
"\n",
"\n",
"class VLLMAnnotator(CompletionAnnotator[SafetyModelResponse]):\n",
" def __init__(self, uid: str, base_url: str, model_name: str):\n",
" def __init__(self, uid: str, base_url: str, secret: VLLMSecret, model_name: str):\n",
" super().__init__(uid=uid)\n",
" self.base_url = base_url\n",
" self.model_name = model_name\n",
" self.client = vllm.Client(base_url=self.base_url)\n",
" self.client = OpenAI(base_url=self.base_url, api_key=secret.value)\n",
"\n",
" def translate_request(self, test_item: TestItem, response: SUTResponse) -> AnnotationRequest:\n",
" \"\"\"Convert the test item and SUT response into an annotation request.\n",
@@ -200,9 +213,10 @@
"outputs": [],
"source": [
"from modelgauge.annotator_registry import ANNOTATORS\n",
"from modelgauge.secret_values import InjectSecret\n",
"\n",
"\n",
"ANNOTATORS.register(VLLMAnnotator, vllm_annotator_uid, vllm_host, vllm_model)"
"ANNOTATORS.register(VLLMAnnotator, vllm_annotator_uid, vllm_host, InjectSecret(VLLMSecret), vllm_model)"
]
},
{
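
For context, registering with `InjectSecret(VLLMSecret)` means the registry supplies the secret when the annotator is instantiated, resolving it from the loaded secrets file. A rough sketch of that flow follows; `load_secrets_from_config` and `make_instance` are assumptions based on typical modelgauge usage and are not part of this PR:

```python
# Hedged sketch: load_secrets_from_config and make_instance are assumed
# from common modelgauge usage, not code introduced by this change.
from modelgauge.config import load_secrets_from_config

secrets = load_secrets_from_config()  # reads config/secrets.toml
annotator = ANNOTATORS.make_instance(vllm_annotator_uid, secrets=secrets)

# The injected VLLMSecret resolves [vllm]/api_key, so the OpenAI client
# inside VLLMAnnotator sends it as the Bearer token to the vLLM server.
```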
12 changes: 0 additions & 12 deletions src/modelplane/utils/vllm.py

This file was deleted.

11 changes: 9 additions & 2 deletions tests/notebooks/mock_vllm_server.py
@@ -4,12 +4,17 @@
import os
import random

from fastapi import FastAPI
from fastapi import FastAPI, Header, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel

app = FastAPI()
allowed_model = os.getenv("VLLM_MODEL", "mlc/not-real-model")
api_key = os.getenv("VLLM_API_KEY")
if not api_key:
raise ValueError(
"VLLM_API_KEY environment variable must be set for the mock server."
)


class CompletionRequest(BaseModel):
@@ -19,7 +24,9 @@ class CompletionRequest(BaseModel):


@app.post("/v1/completions")
async def completions(req: CompletionRequest):
async def completions(req: CompletionRequest, authorization: str = Header(None)):
if authorization != f"Bearer {api_key}":
raise HTTPException(status_code=401, detail="Invalid API key.")
result_text = random.choice(["safe", "unsafe"])
return {
"id": "cmpl-mock",
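
To see the new check end to end, a request without the expected Bearer token should now get a 401, while the OpenAI client configured with the same `VLLM_API_KEY` goes through. A small sketch against a locally running mock server; the port matches docker-compose, and the `model`/`prompt` fields plus the `requests` dependency are assumptions, since `CompletionRequest`'s fields are elided above:

```python
# Hedged sketch: assumes the mock server is reachable on localhost:8001,
# that CompletionRequest accepts "model" and "prompt", and that the
# requests package is installed. None of this is confirmed by the diff.
import os

import requests
from openai import OpenAI

base_url = "http://localhost:8001/v1"
api_key = os.getenv("VLLM_API_KEY", "changeme")

# A wrong (or missing) Authorization header is rejected with 401.
resp = requests.post(
    f"{base_url}/completions",
    json={"model": "mlc/not-real-model", "prompt": "hello"},
    headers={"Authorization": "Bearer wrong-key"},
)
assert resp.status_code == 401

# The OpenAI client sends "Authorization: Bearer <api_key>" automatically.
client = OpenAI(base_url=base_url, api_key=api_key)
completion = client.completions.create(model="mlc/not-real-model", prompt="hello")
print(completion.choices[0].text)  # "safe" or "unsafe" from the mock
```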