Skip to content

Commit 325c4ea

Browse files
committed
fix(vertexai): guard tool call args against non-finite numbers
1 parent 5f73ecd commit 325c4ea

File tree

2 files changed

+130
-3
lines changed

2 files changed

+130
-3
lines changed

libs/vertexai/langchain_google_vertexai/chat_models.py

Lines changed: 85 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from __future__ import annotations # noqa
44
import ast
55
import base64
6+
import math
67
from functools import cached_property
78
import json
89
import logging
@@ -24,9 +25,15 @@
2425
TypedDict,
2526
overload,
2627
)
27-
from collections.abc import AsyncIterator, Iterator, Sequence
28+
from collections.abc import AsyncIterator, Iterator, Sequence, Mapping
2829

2930
import proto # type: ignore[import-untyped]
31+
from google.protobuf.json_format import SerializeToJsonError # type: ignore[import-untyped]
32+
from google.protobuf.struct_pb2 import ( # type: ignore[import-untyped]
33+
ListValue,
34+
Struct,
35+
Value,
36+
)
3037

3138
from langchain_core.callbacks import (
3239
AsyncCallbackManagerForLLMRun,
@@ -614,6 +621,80 @@ def _append_to_content(
614621
raise TypeError(msg)
615622

616623

624+
def _json_safe_number(value: float) -> Union[str, float]:
625+
if math.isnan(value):
626+
return "NaN"
627+
if math.isinf(value):
628+
return "Infinity" if value > 0 else "-Infinity"
629+
return value
630+
631+
632+
def _struct_value_to_jsonable(value: Value) -> Any:
633+
kind = value.WhichOneof("kind")
634+
if kind == "number_value":
635+
return _json_safe_number(value.number_value)
636+
if kind == "string_value":
637+
return value.string_value
638+
if kind == "bool_value":
639+
return bool(value.bool_value)
640+
if kind == "null_value":
641+
return None
642+
if kind == "struct_value":
643+
return _struct_to_jsonable(value.struct_value)
644+
if kind == "list_value":
645+
return [_struct_value_to_jsonable(item) for item in value.list_value.values]
646+
return None
647+
648+
649+
def _struct_to_jsonable(struct: Struct) -> dict[str, Any]:
650+
return {key: _struct_value_to_jsonable(val) for key, val in struct.fields.items()}
651+
652+
653+
def _make_jsonable(value: Any) -> Any:
654+
if isinstance(value, Value):
655+
return _struct_value_to_jsonable(value)
656+
if isinstance(value, Struct):
657+
return _struct_to_jsonable(value)
658+
if isinstance(value, ListValue):
659+
return [_struct_value_to_jsonable(item) for item in value.values]
660+
if isinstance(value, Mapping):
661+
return {str(key): _make_jsonable(val) for key, val in value.items()}
662+
if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)):
663+
return [_make_jsonable(item) for item in value]
664+
if isinstance(value, float):
665+
return _json_safe_number(value)
666+
return value
667+
668+
669+
def _coerce_function_call_args(function_call: FunctionCall) -> dict[str, Any]:
670+
try:
671+
fc_dict = proto.Message.to_dict(function_call)
672+
except SerializeToJsonError:
673+
fc_dict = {}
674+
except TypeError:
675+
fc_dict = {}
676+
677+
args_dict: Any = fc_dict.get("args") if isinstance(fc_dict, dict) else None
678+
if isinstance(args_dict, dict):
679+
return dict(args_dict)
680+
681+
struct_args = getattr(function_call, "args", None)
682+
if isinstance(struct_args, Struct):
683+
return _struct_to_jsonable(struct_args)
684+
if isinstance(struct_args, Mapping):
685+
return {str(key): _make_jsonable(val) for key, val in struct_args.items()}
686+
687+
if struct_args is not None:
688+
try:
689+
fallback_dict = proto.Message.to_dict(struct_args)
690+
except Exception:
691+
fallback_dict = {}
692+
if isinstance(fallback_dict, dict):
693+
return fallback_dict
694+
695+
return {}
696+
697+
617698
@overload
618699
def _parse_response_candidate(
619700
response_candidate: Candidate, streaming: Literal[False] = False
@@ -657,9 +738,10 @@ def _parse_response_candidate(
657738
# but in general the full set of function calls is stored in tool_calls.
658739
function_call = {"name": part.function_call.name}
659740
# dump to match other function calling llm for now
660-
function_call_args_dict = proto.Message.to_dict(part.function_call)["args"]
741+
function_call_args_dict = _coerce_function_call_args(part.function_call)
661742
function_call["arguments"] = json.dumps(
662-
{k: function_call_args_dict[k] for k in function_call_args_dict}
743+
function_call_args_dict,
744+
allow_nan=False,
663745
)
664746
additional_kwargs["function_call"] = function_call
665747

libs/vertexai/tests/unit_tests/test_chat_models.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1016,6 +1016,51 @@ def test_default_params_gemini() -> None:
10161016
},
10171017
),
10181018
),
1019+
(
1020+
Candidate(
1021+
content=Content(
1022+
role="model",
1023+
parts=[
1024+
Part(
1025+
function_call=FunctionCall(
1026+
name="handle_numbers",
1027+
args={
1028+
"positive": float("inf"),
1029+
"negative": float("-inf"),
1030+
"not_a_number": float("nan"),
1031+
},
1032+
),
1033+
),
1034+
],
1035+
)
1036+
),
1037+
AIMessage(
1038+
content="",
1039+
tool_calls=[
1040+
create_tool_call(
1041+
name="handle_numbers",
1042+
args={
1043+
"positive": "Infinity",
1044+
"negative": "-Infinity",
1045+
"not_a_number": "NaN",
1046+
},
1047+
id="00000000-0000-0000-0000-00000000000",
1048+
),
1049+
],
1050+
additional_kwargs={
1051+
"function_call": {
1052+
"name": "handle_numbers",
1053+
"arguments": json.dumps(
1054+
{
1055+
"positive": "Infinity",
1056+
"negative": "-Infinity",
1057+
"not_a_number": "NaN",
1058+
}
1059+
),
1060+
}
1061+
},
1062+
),
1063+
),
10191064
],
10201065
)
10211066
def test_parse_response_candidate(raw_candidate, expected) -> None:

0 commit comments

Comments
 (0)