
Commit ee7c373

Feature/prompt management (#200)
* [feat] prompt management
* [feat] testing
* [feat] only one active prompt
1 parent 91f232e commit ee7c373

File tree

7 files changed: +337 -0 lines changed


libs/tracker/llmstudio_tracker/prompt_manager/__init__.py

Whitespace-only changes.

libs/tracker/llmstudio_tracker/prompt_manager/crud.py

Lines changed: 91 additions & 0 deletions
@@ -0,0 +1,91 @@
from llmstudio_tracker.prompt_manager import models, schemas
from sqlalchemy.orm import Session


def get_prompt_by_name_model_provider(
    db: Session, name: str, model: str, provider: str
):
    return (
        db.query(models.PromptDefault)
        .filter(
            models.PromptDefault.name == name,
            models.PromptDefault.model == model,
            models.PromptDefault.provider == provider,
            models.PromptDefault.is_active == True,
        )
        .order_by(models.PromptDefault.version.desc())
        .first()
    )


def get_prompt_by_id(db: Session, prompt_id: str):
    return (
        db.query(models.PromptDefault)
        .filter(models.PromptDefault.prompt_id == prompt_id)
        .first()
    )


def get_prompt(
    db: Session,
    prompt_id: str = None,
    name: str = None,
    model: str = None,
    provider: str = None,
):
    if prompt_id:
        return get_prompt_by_id(db, prompt_id)
    else:
        return get_prompt_by_name_model_provider(db, name, model, provider)


def add_prompt(db: Session, prompt: schemas.PromptDefault):

    prompt_created = models.PromptDefault.create_with_incremental_version(
        db,
        config=prompt.config,
        prompt=prompt.prompt,
        is_active=prompt.is_active,
        name=prompt.name,
        label=prompt.label,
        model=prompt.model,
        provider=prompt.provider,
    )
    db.add(prompt_created)
    db.commit()
    db.refresh(prompt_created)
    return prompt_created


def update_prompt(db: Session, prompt: schemas.PromptDefault):
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    existing_prompt.config = prompt.config
    existing_prompt.prompt = prompt.prompt
    existing_prompt.is_active = prompt.is_active
    existing_prompt.name = prompt.name
    existing_prompt.model = prompt.model
    existing_prompt.provider = prompt.provider
    existing_prompt.version = prompt.version
    existing_prompt.label = prompt.label

    db.commit()
    db.refresh(existing_prompt)
    return existing_prompt


def delete_prompt(db: Session, prompt: schemas.PromptDefault):
    if prompt.prompt_id:
        existing_prompt = get_prompt_by_id(db, prompt.prompt_id)
    else:
        existing_prompt = get_prompt_by_name_model_provider(
            db, prompt.name, prompt.model, prompt.provider
        )

    db.delete(existing_prompt)
    db.commit()
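
A minimal usage sketch for the CRUD helpers above (not part of the commit). It assumes llmstudio_tracker.database exposes a session factory; "SessionLocal" is a placeholder name for whatever factory the package actually provides.

# Sketch only, not part of the commit. "SessionLocal" is a hypothetical name
# for the tracker's SQLAlchemy session factory.
from llmstudio_tracker.database import SessionLocal  # hypothetical import
from llmstudio_tracker.prompt_manager import crud, schemas

db = SessionLocal()
try:
    # Re-adding the same (name, model, provider) produces versions 1, 2, ...
    created = crud.add_prompt(
        db,
        schemas.PromptDefault(
            prompt="You are a helpful assistant.",
            name="support-agent",
            model="gpt-4o",
            provider="openai",
            is_active=True,
        ),
    )
    # Without a prompt_id, the lookup falls back to the highest-version active
    # prompt for the (name, model, provider) combination.
    latest = crud.get_prompt(db, name="support-agent", model="gpt-4o", provider="openai")
    print(created.version, latest.prompt_id)
finally:
    db.close()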

libs/tracker/llmstudio_tracker/prompt_manager/endpoints.py

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
from fastapi import APIRouter, Depends
from llmstudio_tracker.database import engine, get_db
from llmstudio_tracker.prompt_manager import crud, models, schemas
from sqlalchemy.orm import Session

models.Base.metadata.create_all(bind=engine)


class PromptsRoutes:
    def __init__(self, router: APIRouter):
        self.router = router
        self.define_routes()

    def define_routes(self):
        self.router.post(
            "/add/prompt",
            response_model=schemas.PromptDefault,
        )(self.add_prompt)

        self.router.get("/get/prompt", response_model=schemas.PromptDefault)(
            self.get_prompt
        )

        self.router.patch("/update/prompt", response_model=schemas.PromptDefault)(
            self.update_prompt
        )

        self.router.delete("/delete/prompt")(self.delete_prompt)

    async def add_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        return crud.add_prompt(db=db, prompt=prompt)

    async def update_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        return crud.update_prompt(db, prompt)

    async def get_prompt(
        self,
        prompt_info: schemas.PromptInfo,
        db: Session = Depends(get_db),
    ):
        return crud.get_prompt(
            db,
            prompt_id=prompt_info.prompt_id,
            name=prompt_info.name,
            model=prompt_info.model,
            provider=prompt_info.provider,
        )

    async def delete_prompt(
        self, prompt: schemas.PromptDefault, db: Session = Depends(get_db)
    ):
        return crud.delete_prompt(db, prompt)
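
A rough illustration of calling the new routes over HTTP (not part of the commit). It assumes the tracker is listening on localhost:50002 and that the router prefix resolves to /api/tracking, the same prefix the PromptManager client further down targets.

# Sketch only, not part of the commit. Host, port, and prefix are assumptions.
import requests

base = "http://localhost:50002/api/tracking"
payload = {
    "prompt": "You are a helpful assistant.",
    "name": "support-agent",
    "model": "gpt-4o",
    "provider": "openai",
    "is_active": True,
}
created = requests.post(f"{base}/add/prompt", json=payload, timeout=10).json()

# The GET route expects a PromptInfo document in the request body, so the
# filter is sent as JSON rather than as query parameters.
lookup = {"name": "support-agent", "model": "gpt-4o", "provider": "openai"}
fetched = requests.get(f"{base}/get/prompt", json=lookup, timeout=10).json()
print(created["prompt_id"], fetched["version"])
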
Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
import json

import requests
from llmstudio_tracker.prompt_manager.schemas import PromptDefault
from llmstudio_tracker.tracker import TrackingConfig


class PromptManager:
    def __init__(self, tracking_config: TrackingConfig):
        self.tracking_url = tracking_config.url
        self._session = requests.Session()

    def add_prompt(self, prompt: PromptDefault):
        req = self._session.post(
            f"{self.tracking_url}/api/tracking/add/prompt",
            headers={"accept": "application/json", "Content-Type": "application/json"},
            data=prompt.model_dump_json(),
            timeout=100,
        )
        return req

    def delete_prompt(self, prompt: PromptDefault):
        req = self._session.delete(
            f"{self.tracking_url}/api/tracking/delete/prompt",
            headers={"accept": "application/json", "Content-Type": "application/json"},
            data=prompt.model_dump_json(),
            timeout=100,
        )
        return req

    def update_prompt(self, prompt: PromptDefault):
        req = self._session.patch(
            f"{self.tracking_url}/api/tracking/update/prompt",
            headers={"accept": "application/json", "Content-Type": "application/json"},
            data=prompt.model_dump_json(),
            timeout=100,
        )
        return req

    def get_prompt(
        self,
        prompt_id: str = None,
        name: str = None,
        model: str = None,
        provider: str = None,
    ):

        data = {
            "prompt_id": prompt_id,
            "name": name,
            "model": model,
            "provider": provider,
        }

        req = self._session.get(
            f"{self.tracking_url}/api/tracking/get/prompt",
            headers={"accept": "application/json", "Content-Type": "application/json"},
            timeout=100,
            data=json.dumps(data),
        )
        return req
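
A hedged usage sketch for the client above (not part of the commit). The PromptManager import path and the TrackingConfig constructor arguments are assumptions; the class itself only needs a config object that exposes a url attribute pointing at a running tracker.

# Sketch only, not part of the commit. Import path and constructor arguments
# below are assumptions.
from llmstudio_tracker.prompt_manager.prompt_manager import PromptManager  # assumed path
from llmstudio_tracker.prompt_manager.schemas import PromptDefault
from llmstudio_tracker.tracker import TrackingConfig

config = TrackingConfig(host="localhost", port="50002")  # assumed parameters
manager = PromptManager(config)

response = manager.add_prompt(
    PromptDefault(
        prompt="You are a helpful assistant.",
        name="support-agent",
        model="gpt-4o",
        provider="openai",
        is_active=True,
    )
)
print(response.status_code, response.json())

# Each method returns the raw requests.Response, so the caller decides how to
# handle status codes and deserialization.
fetched = manager.get_prompt(name="support-agent", model="gpt-4o", provider="openai")
print(fetched.json()["version"])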

libs/tracker/llmstudio_tracker/prompt_manager/models.py

Lines changed: 102 additions & 0 deletions
@@ -0,0 +1,102 @@
import uuid
from datetime import datetime, timezone

from llmstudio_tracker.config import DB_TYPE
from llmstudio_tracker.database import Base
from llmstudio_tracker.db_utils import JSONEncodedDict
from sqlalchemy import (
    JSON,
    Boolean,
    Column,
    DateTime,
    Integer,
    String,
    UniqueConstraint,
    event,
    func,
)
from sqlalchemy.orm import Session


class PromptDefault(Base):
    __tablename__ = "prompts"

    if DB_TYPE == "bigquery":
        prompt_id = Column(
            String,
            primary_key=True,
            default=lambda: str(uuid.uuid4()),
        )
        config = Column(JSONEncodedDict, nullable=True)
    else:
        prompt_id = Column(
            String, primary_key=True, default=lambda: str(uuid.uuid4())
        )  # Generate UUID as a string
        config = Column(JSON, nullable=True)

    prompt = Column(String)
    is_active = Column(Boolean, default=False)
    name = Column(String, nullable=False)
    model = Column(String, nullable=False)
    provider = Column(String, nullable=False)
    version = Column(Integer, nullable=False)
    label = Column(String)
    updated_at = Column(
        DateTime(timezone=True),
        onupdate=lambda: datetime.now(timezone.utc),
        default=lambda: datetime.now(timezone.utc),
    )
    created_at = Column(
        DateTime(timezone=True), default=lambda: datetime.now(timezone.utc)
    )

    __table_args__ = (
        UniqueConstraint(
            "name", "provider", "model", "version", name="uq_prompt_version"
        ),
    )

    @staticmethod
    def get_next_version(session, name, model, provider):
        """
        Get the next version number for a combination of name, model, and provider.
        """
        max_version = (
            session.query(func.max(PromptDefault.version))
            .filter_by(name=name, model=model, provider=provider)
            .scalar()
        )
        return (max_version or 0) + 1

    @classmethod
    def create_with_incremental_version(cls, session, **kwargs):
        """
        Create a new PromptDefault entry with an incremental version.
        """
        name = kwargs.get("name")
        model = kwargs.get("model")
        provider = kwargs.get("provider")
        if not all([name, model, provider]):
            raise ValueError("name, model, and provider must be provided")

        kwargs["version"] = cls.get_next_version(session, name, model, provider)

        instance = cls(**kwargs)
        session.add(instance)
        return instance

    @event.listens_for(Session, "before_flush")
    def ensure_single_active_prompt(session, flush_context, instances):
        """
        Ensures only one PromptDefault entry per (name, model, provider) can have is_active=True.
        If a new entry is set as is_active=True, deactivate others in the same group.
        """
        for instance in session.new.union(session.dirty):
            if isinstance(instance, PromptDefault) and instance.is_active:
                session.query(PromptDefault).filter(
                    PromptDefault.name == instance.name,
                    PromptDefault.model == instance.model,
                    PromptDefault.provider == instance.provider,
                    PromptDefault.is_active == True,
                    PromptDefault.prompt_id != instance.prompt_id,
                ).update({"is_active": False}, synchronize_session="fetch")
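
The sketch below (not part of the commit) exercises the incremental versioning and the single-active rule against an in-memory SQLite database; it assumes the tracker package imports cleanly without extra configuration.

# Sketch only, not part of the commit.
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

from llmstudio_tracker.database import Base
from llmstudio_tracker.prompt_manager.models import PromptDefault

engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()

common = dict(name="support-agent", model="gpt-4o", provider="openai")

# Versions are assigned per (name, model, provider).
v1 = PromptDefault.create_with_incremental_version(
    session, prompt="v1 text", is_active=True, **common
)
session.commit()

v2 = PromptDefault.create_with_incremental_version(
    session, prompt="v2 text", is_active=True, **common
)
session.commit()  # before_flush listener deactivates v1

session.refresh(v1)
print(v1.version, v2.version)      # 1 2
print(v1.is_active, v2.is_active)  # False True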

libs/tracker/llmstudio_tracker/prompt_manager/schemas.py

Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
from datetime import datetime
from typing import Dict, Optional

from pydantic import BaseModel


class PromptInfo(BaseModel):
    prompt_id: Optional[str] = None
    name: Optional[str] = None
    model: Optional[str] = None
    provider: Optional[str] = None


class PromptDefault(BaseModel):
    prompt_id: Optional[str] = None
    config: Optional[Dict] = {}
    prompt: str
    is_active: Optional[bool] = None
    name: str
    version: Optional[int] = None
    label: Optional[str] = "production"
    model: str
    provider: str
    updated_at: Optional[datetime] = None
    created_at: Optional[datetime] = None
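
For reference (not part of the commit), this is how a caller would build the payloads these schemas describe; model_dump_json produces the same JSON the PromptManager client sends over HTTP.

# Sketch only, not part of the commit.
from llmstudio_tracker.prompt_manager.schemas import PromptDefault, PromptInfo

prompt = PromptDefault(
    prompt="Summarize the following ticket: {ticket}",
    name="ticket-summarizer",
    model="gpt-4o",
    provider="openai",
    is_active=True,
)
# version and the timestamps are optional and filled in on the tracker side;
# label defaults to "production".
print(prompt.model_dump_json())

# Lookups use the lighter PromptInfo schema.
lookup = PromptInfo(name="ticket-summarizer", model="gpt-4o", provider="openai")
print(lookup.model_dump_json(exclude_none=True))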

libs/tracker/llmstudio_tracker/server.py

Lines changed: 2 additions & 0 deletions
@@ -6,6 +6,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 from llmstudio_tracker.config import TRACKING_HOST, TRACKING_PORT
 from llmstudio_tracker.logs.endpoints import LogsRoutes
+from llmstudio_tracker.prompt_manager.endpoints import PromptsRoutes
 from llmstudio_tracker.session.endpoints import SessionsRoutes
 from llmstudio_tracker.utils import get_current_version

@@ -42,6 +43,7 @@ def health_check():
 tracking_router = APIRouter(prefix=TRACKING_BASE_ENDPOINT)
 LogsRoutes(tracking_router)
 SessionsRoutes(tracking_router)
+PromptsRoutes(tracking_router)

 app.include_router(tracking_router)
