44from fastapi import APIRouter
55from litellm .types .utils import ModelResponse
66from tlm .chat import achat , acost
7- from tlm .config .types import ModelProvider
7+ from tlm .config .types import ModelProvider , TLMOptions
88from tlm .types import Message , TokenCost , TLMScore
99
1010from src .schemas .chat import (
@@ -111,6 +111,7 @@ async def _handle_request(
111111 model_provider = model_provider ,
112112 custom_eval_config = input .options .custom_eval_criteria ,
113113 completion = completion ,
114+ tlm_options = construct_tlm_options (input ),
114115 )
115116
116117 # TODO:properly handle multiple completions
@@ -133,6 +134,12 @@ async def _handle_request(
133134 return result
134135
135136
def construct_tlm_options(input: ChatInput) -> TLMOptions:
    """Build the ``TLMOptions`` for a scoring call from the request payload.

    Currently only the consistency-sample count is forwarded from the
    request's options; add further per-request TLM knobs here as they are
    exposed on ``ChatInput``.
    """
    num_samples = input.options.num_consistency_samples
    return TLMOptions(num_consistency_samples=num_samples)
141+
142+
136143def maybe_add_log (tlm_score : TLMScore , log : list [ChatLogField ]) -> dict [str , Any ]:
137144 log_dict = {}
138145 if ChatLogField .PER_FIELD_SCORE in log :
@@ -150,8 +157,12 @@ async def chat_costs(
150157 input : ChatInput ,
151158) -> TokenCost :
152159 reference_config = get_reference_config (input )
153- messages = [Message .model_validate (message .model_dump ()) for message in input .messages ]
154- model_provider = input .options .model_provider or model_registry .get (input .options .model )
160+ messages = [
161+ Message .model_validate (message .model_dump ()) for message in input .messages
162+ ]
163+ model_provider = input .options .model_provider or model_registry .get (
164+ input .options .model
165+ )
155166
156167 return await acost (
157168 messages ,
0 commit comments