Skip to content

Commit b1b1a3e

Browse files
authored
Converting bad ToolRequestMessage from agent LLM into MalformedMessageError (#302)
1 parent 676ab98 commit b1b1a3e

File tree

2 files changed

+46
-12
lines changed

2 files changed

+46
-12
lines changed

src/aviary/tools/utils.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from functools import partial
44
from typing import TYPE_CHECKING, Any, ClassVar, cast
55

6-
from pydantic import BaseModel, Field
6+
from pydantic import BaseModel, Field, ValidationError
77

88
from aviary.message import MalformedMessageError, Message
99

@@ -103,24 +103,32 @@ async def __call__(
103103

104104
if (num_choices := len(model_response.choices)) != 1:
105105
raise MalformedMessageError(
106-
f"Expected one choice in LiteLLM model response, got {num_choices}"
106+
f"Expected one choice in model response, got {num_choices}"
107107
f" choices, full response was {model_response}."
108108
)
109109
choice = model_response.choices[0]
110110
if choice.finish_reason not in expected_finish_reason:
111111
raise MalformedMessageError(
112-
f"Expected a finish reason in {expected_finish_reason} in LiteLLM"
112+
f"Expected a finish reason in {expected_finish_reason} in"
113113
f" model response, got finish reason {choice.finish_reason!r}, full"
114-
f" response was {model_response} and tool choice was {tool_choice}."
114+
f" response was {model_response} and tool choice was {tool_choice!r}."
115115
)
116116
usage = model_response.usage
117-
selection = ToolRequestMessage(
118-
**choice.message.model_dump(),
119-
info={
120-
"usage": (usage.prompt_tokens, usage.completion_tokens),
121-
"model": self._model_name,
122-
},
123-
)
117+
try:
118+
selection = ToolRequestMessage(
119+
**choice.message.model_dump(),
120+
info={
121+
"usage": (usage.prompt_tokens, usage.completion_tokens),
122+
"model": self._model_name,
123+
},
124+
)
125+
except ValidationError as exc:
126+
raise MalformedMessageError(
127+
f"Failed to convert model response's message {choice.message}"
128+
f" into a tool request message."
129+
f" Got finish reason {choice.finish_reason!r}, full"
130+
f" response was {model_response} and tool choice was {tool_choice!r}."
131+
) from exc
124132
if self._ledger is not None:
125133
self._ledger.messages.append(selection)
126134
return selection

tests/test_envs.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import litellm
1111
import pytest
1212
from httpx import ASGITransport, AsyncClient
13-
from pydantic import BaseModel
13+
from pydantic import BaseModel, ValidationError
1414
from pytest_subtests import SubTests
1515

1616
from aviary.core import (
@@ -30,6 +30,7 @@
3030
ToolSelectorLedger,
3131
)
3232
from aviary.dataset_server import TaskDatasetServer
33+
from aviary.message import MalformedMessageError
3334
from aviary.tools import FunctionInfo, Messages
3435
from tests import CILLMModelNames
3536
from tests.conftest import VCR_DEFAULT_MATCH_ON
@@ -483,6 +484,31 @@ async def inner1() -> None: # noqa: RUF029
483484
"Expected sub-exceptions to be displayed"
484485
)
485486

487+
@pytest.mark.asyncio
488+
async def test_tool_selector_bad_agent_llm_response(
489+
self, dummy_env: DummyEnv
490+
) -> None:
491+
obs, tools = await dummy_env.reset()
492+
493+
async def stub_acompletion(*_, **__) -> litellm.ModelResponse: # noqa: RUF029
494+
return litellm.ModelResponse(
495+
choices=[
496+
litellm.Choices(
497+
# Malformatted because it contains null tool calls
498+
message=ToolRequestMessage().model_dump() | {"tool_calls": None}
499+
)
500+
]
501+
)
502+
503+
selector = ToolSelector("stub", acompletion=stub_acompletion)
504+
with pytest.raises(
505+
MalformedMessageError, match="tool request message"
506+
) as exc_info:
507+
await selector(obs, tools=tools)
508+
assert isinstance(exc_info.value.__cause__, ValidationError), (
509+
"We should be able to retrieve the original validation error"
510+
)
511+
486512
@pytest.mark.vcr(match_on=[*VCR_DEFAULT_MATCH_ON, "body"])
487513
@pytest.mark.parametrize("model_name", [CILLMModelNames.OPENAI.value])
488514
@pytest.mark.asyncio

Comments (0)