diff --git a/services/chat-backend/pyproject.toml b/services/chat-backend/pyproject.toml index d2eef9d..83e1040 100644 --- a/services/chat-backend/pyproject.toml +++ b/services/chat-backend/pyproject.toml @@ -56,4 +56,4 @@ show_missing = true default-groups = ["dev"] [tlm-core] -version = "v0.0.41" +version = "v0.0.42" diff --git a/services/chat-backend/src/routers/chat.py b/services/chat-backend/src/routers/chat.py index d0cc257..33a12b9 100644 --- a/services/chat-backend/src/routers/chat.py +++ b/services/chat-backend/src/routers/chat.py @@ -13,7 +13,7 @@ ) from src.services.azure import apply_azure_ad_token from src.services.chat_config import get_reference_config -from src.services.chat_log_filter import filter_log +from src.schemas.chat.input import ChatLogField from src.utils.models import registry as model_registry @@ -99,18 +99,23 @@ async def _handle_request( # TODO:properly handle multiple completions # or maybe don't support n > 1 for now - model_completion, score = completions[0] + model_completion, tlm_score = completions[0] result = { **model_completion.model_dump(), "tlm_metadata": { - "trustworthiness_score": float(score), + "trustworthiness_score": float(tlm_score.trustworthiness_score), }, } - filtered_log = filter_log(log, input.options.log) - if metadata_log := filtered_log.get(str(0), {}): - result["tlm_metadata"]["metadata"] = metadata_log + # add explanation to the result if specified + if ChatLogField.EXPLANATION in input.options.log: + result["tlm_metadata"]["log"] = {"explanation": tlm_score.explanation} + + # TODO: these are internal logs, do not expose for now + # filtered_log = filter_log(log, input.options.log) + # if metadata_log := filtered_log.get(str(0), {}): + # result["tlm_metadata"]["metadata"] = metadata_log return result diff --git a/services/chat-backend/src/schemas/chat/input.py b/services/chat-backend/src/schemas/chat/input.py index 9e27dd5..9ec0dd6 100644 --- a/services/chat-backend/src/schemas/chat/input.py +++ b/services/chat-backend/src/schemas/chat/input.py @@ -13,6 +13,8 @@ class ChatLogField(enum.Enum): """The fields to log for a chat completion.""" + EXPLANATION = "explanation" + RESULT_COMPLETION = "result_completion" RESULT_SCORE = "result_score"