|
3 | 3 | from __future__ import annotations # noqa |
4 | 4 | import ast |
5 | 5 | import base64 |
| 6 | +import math |
6 | 7 | from functools import cached_property |
7 | 8 | import json |
8 | 9 | import logging |
|
24 | 25 | TypedDict, |
25 | 26 | overload, |
26 | 27 | ) |
27 | | -from collections.abc import AsyncIterator, Iterator, Sequence |
| 28 | +from collections.abc import AsyncIterator, Iterator, Sequence, Mapping |
28 | 29 |
|
29 | 30 | import proto # type: ignore[import-untyped] |
| 31 | +from google.protobuf.json_format import SerializeToJsonError # type: ignore[import-untyped] |
| 32 | +from google.protobuf.struct_pb2 import ( # type: ignore[import-untyped] |
| 33 | + ListValue, |
| 34 | + Struct, |
| 35 | + Value, |
| 36 | +) |
30 | 37 |
|
31 | 38 | from langchain_core.callbacks import ( |
32 | 39 | AsyncCallbackManagerForLLMRun, |
@@ -614,6 +621,80 @@ def _append_to_content( |
614 | 621 | raise TypeError(msg) |
615 | 622 |
|
616 | 623 |
|
| 624 | +def _json_safe_number(value: float) -> Union[str, float]: |
| 625 | + if math.isnan(value): |
| 626 | + return "NaN" |
| 627 | + if math.isinf(value): |
| 628 | + return "Infinity" if value > 0 else "-Infinity" |
| 629 | + return value |
| 630 | + |
| 631 | + |
| 632 | +def _struct_value_to_jsonable(value: Value) -> Any: |
| 633 | + kind = value.WhichOneof("kind") |
| 634 | + if kind == "number_value": |
| 635 | + return _json_safe_number(value.number_value) |
| 636 | + if kind == "string_value": |
| 637 | + return value.string_value |
| 638 | + if kind == "bool_value": |
| 639 | + return bool(value.bool_value) |
| 640 | + if kind == "null_value": |
| 641 | + return None |
| 642 | + if kind == "struct_value": |
| 643 | + return _struct_to_jsonable(value.struct_value) |
| 644 | + if kind == "list_value": |
| 645 | + return [_struct_value_to_jsonable(item) for item in value.list_value.values] |
| 646 | + return None |
| 647 | + |
| 648 | + |
| 649 | +def _struct_to_jsonable(struct: Struct) -> dict[str, Any]: |
| 650 | + return {key: _struct_value_to_jsonable(val) for key, val in struct.fields.items()} |
| 651 | + |
| 652 | + |
| 653 | +def _make_jsonable(value: Any) -> Any: |
| 654 | + if isinstance(value, Value): |
| 655 | + return _struct_value_to_jsonable(value) |
| 656 | + if isinstance(value, Struct): |
| 657 | + return _struct_to_jsonable(value) |
| 658 | + if isinstance(value, ListValue): |
| 659 | + return [_struct_value_to_jsonable(item) for item in value.values] |
| 660 | + if isinstance(value, Mapping): |
| 661 | + return {str(key): _make_jsonable(val) for key, val in value.items()} |
| 662 | + if isinstance(value, Sequence) and not isinstance(value, (str, bytes, bytearray)): |
| 663 | + return [_make_jsonable(item) for item in value] |
| 664 | + if isinstance(value, float): |
| 665 | + return _json_safe_number(value) |
| 666 | + return value |
| 667 | + |
| 668 | + |
| 669 | +def _coerce_function_call_args(function_call: FunctionCall) -> dict[str, Any]: |
| 670 | + try: |
| 671 | + fc_dict = proto.Message.to_dict(function_call) |
| 672 | + except SerializeToJsonError: |
| 673 | + fc_dict = {} |
| 674 | + except TypeError: |
| 675 | + fc_dict = {} |
| 676 | + |
| 677 | + args_dict: Any = fc_dict.get("args") if isinstance(fc_dict, dict) else None |
| 678 | + if isinstance(args_dict, dict): |
| 679 | + return dict(args_dict) |
| 680 | + |
| 681 | + struct_args = getattr(function_call, "args", None) |
| 682 | + if isinstance(struct_args, Struct): |
| 683 | + return _struct_to_jsonable(struct_args) |
| 684 | + if isinstance(struct_args, Mapping): |
| 685 | + return {str(key): _make_jsonable(val) for key, val in struct_args.items()} |
| 686 | + |
| 687 | + if struct_args is not None: |
| 688 | + try: |
| 689 | + fallback_dict = proto.Message.to_dict(struct_args) |
| 690 | + except Exception: |
| 691 | + fallback_dict = {} |
| 692 | + if isinstance(fallback_dict, dict): |
| 693 | + return fallback_dict |
| 694 | + |
| 695 | + return {} |
| 696 | + |
| 697 | + |
617 | 698 | @overload |
618 | 699 | def _parse_response_candidate( |
619 | 700 | response_candidate: Candidate, streaming: Literal[False] = False |
@@ -657,9 +738,10 @@ def _parse_response_candidate( |
657 | 738 | # but in general the full set of function calls is stored in tool_calls. |
658 | 739 | function_call = {"name": part.function_call.name} |
659 | 740 | # dump to match other function calling llm for now |
660 | | - function_call_args_dict = proto.Message.to_dict(part.function_call)["args"] |
| 741 | + function_call_args_dict = _coerce_function_call_args(part.function_call) |
661 | 742 | function_call["arguments"] = json.dumps( |
662 | | - {k: function_call_args_dict[k] for k in function_call_args_dict} |
| 743 | + function_call_args_dict, |
| 744 | + allow_nan=False, |
663 | 745 | ) |
664 | 746 | additional_kwargs["function_call"] = function_call |
665 | 747 |
|
|
0 commit comments