Skip to content

Commit 5cee348

Browse files
authored
Add num_consistency_samples arg (#31)
1 parent 02f15c3 commit 5cee348

File tree

3 files changed

+16
-4
lines changed

3 files changed

+16
-4
lines changed

services/chat-backend/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,4 @@ show_missing = true
5757
default-groups = ["dev"]
5858

5959
[tlm-core]
60-
version = "v0.0.47"
60+
version = "v0.0.48"

services/chat-backend/src/routers/chat.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from fastapi import APIRouter
55
from litellm.types.utils import ModelResponse
66
from tlm.chat import achat, acost
7-
from tlm.config.types import ModelProvider
7+
from tlm.config.types import ModelProvider, TLMOptions
88
from tlm.types import Message, TokenCost, TLMScore
99

1010
from src.schemas.chat import (
@@ -111,6 +111,7 @@ async def _handle_request(
111111
model_provider=model_provider,
112112
custom_eval_config=input.options.custom_eval_criteria,
113113
completion=completion,
114+
tlm_options=construct_tlm_options(input),
114115
)
115116

116117
# TODO:properly handle multiple completions
@@ -133,6 +134,12 @@ async def _handle_request(
133134
return result
134135

135136

def construct_tlm_options(input: ChatInput) -> TLMOptions:
    """Translate the chat request's options into a TLMOptions instance.

    Currently forwards only ``num_consistency_samples``; extend here as
    more per-request TLM tuning knobs are exposed on ChatInput.
    """
    samples = input.options.num_consistency_samples
    return TLMOptions(num_consistency_samples=samples)
141+
142+
136143
def maybe_add_log(tlm_score: TLMScore, log: list[ChatLogField]) -> dict[str, Any]:
137144
log_dict = {}
138145
if ChatLogField.PER_FIELD_SCORE in log:
@@ -150,8 +157,12 @@ async def chat_costs(
150157
input: ChatInput,
151158
) -> TokenCost:
152159
reference_config = get_reference_config(input)
153-
messages = [Message.model_validate(message.model_dump()) for message in input.messages]
154-
model_provider = input.options.model_provider or model_registry.get(input.options.model)
160+
messages = [
161+
Message.model_validate(message.model_dump()) for message in input.messages
162+
]
163+
model_provider = input.options.model_provider or model_registry.get(
164+
input.options.model
165+
)
155166

156167
return await acost(
157168
messages,

services/chat-backend/src/schemas/chat/input.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class ChatTlmInput(BaseModel):
5555

5656
model_provider: ModelProviderInput | None = None
5757
custom_eval_criteria: list[CustomEvalConfig] = Field(default_factory=list)
58+
num_consistency_samples: int | None = None
5859

5960
log: list[ChatLogField] = Field(
6061
default_factory=lambda: [

0 commit comments

Comments (0)