|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
| 15 | +import copy |
15 | 16 | import functools |
| 17 | +import inspect |
16 | 18 | import json |
17 | 19 | import logging |
18 | 20 | import os |
|
28 | 30 | ContentListUnionDict, |
29 | 31 | ContentUnion, |
30 | 32 | ContentUnionDict, |
| 33 | + GenerateContentConfig, |
31 | 34 | GenerateContentConfigOrDict, |
32 | 35 | GenerateContentResponse, |
| 36 | + ToolListUnionDict, |
33 | 37 | ) |
34 | 38 |
|
35 | 39 | from opentelemetry import trace |
@@ -495,6 +499,81 @@ def _record_duration_metric(self): |
495 | 499 | ) |
496 | 500 |
|
497 | 501 |
|
| 502 | +def _record_tool_call_args( |
| 503 | + otel_wrapper: OTelWrapper, |
| 504 | + tool_name: str, |
| 505 | + args: list[Any], |
| 506 | + kwargs: dict[Any, Any]): |
| 507 | + attributes = { |
| 508 | + code_attributes.CODE_FUNCTION_NAME: tool_name, |
| 509 | + } |
| 510 | + body = { |
| 511 | + "positional_arguments": [_to_dict(arg) for arg in args], |
| 512 | + "keyword_arguments": dict([(key, _to_dict(value)) for (key, value) in kwargs.items()]) |
| 513 | + } |
| 514 | + otel_wrapper.log_tool_call(attributes, body) |
| 515 | + |
| 516 | + |
| 517 | +def _record_tool_call_result( |
| 518 | + otel_wrapper: OTelWrapper, |
| 519 | + tool_name: str, |
| 520 | + args: list[Any], |
| 521 | + kwargs: dict[Any, Any], |
| 522 | + result: Any): |
| 523 | + attributes = { |
| 524 | + code_attributes.CODE_FUNCTION_NAME: tool_name, |
| 525 | + } |
| 526 | + body = { |
| 527 | + "positional_arguments": [_to_dict(arg) for arg in args], |
| 528 | + "keyword_arguments": dict([(key, _to_dict(value)) for (key, value) in kwargs.items()]), |
| 529 | + "return_value": _to_dict(result), |
| 530 | + } |
| 531 | + otel_wrapper.log_tool_call_result(attributes, body) |
| 532 | + |
| 533 | + |
| 534 | +def _wrapped_tool(otel_wrapper: OTelWrapper, tool: ToolListUnionDict): |
| 535 | + if not callable(tool): |
| 536 | + return tool |
| 537 | + if inspect.iscoroutinefunction(tool): |
| 538 | + return tool |
| 539 | + tool_name = tool.__name__ |
| 540 | + should_record_contents = flags.is_content_recording_enabled() |
| 541 | + @functools.wraps(tool) |
| 542 | + def wrapped_tool(*args, **kwargs): |
| 543 | + with otel_wrapper.start_as_current_span( |
| 544 | + f'tool_call {tool_name}', |
| 545 | + attributes={ |
| 546 | + code_attributes.CODE_FUNCTION_NAME: tool.__name__, |
| 547 | + }): |
| 548 | + if should_record_contents: |
| 549 | + _record_tool_call_args(otel_wrapper, tool.__name__, args, kwargs) |
| 550 | + result = tool(*args, **kwargs) |
| 551 | + if should_record_contents: |
| 552 | + _record_tool_call_result(otel_wrapper, tool.__name__, args, kwargs, result) |
| 553 | + return result |
| 554 | + return wrapped_tool |
| 555 | + |
| 556 | + |
| 557 | +def _wrapped_config_with_tools( |
| 558 | + otel_wrapper: OTelWrapper, |
| 559 | + config: GenerateContentConfig) -> GenerateContentConfig: |
| 560 | + result = copy.copy(config) |
| 561 | + result.tool = [_wrapped_tool(otel_wrapper, tool) for tool in config.tools] |
| 562 | + return result |
| 563 | + |
| 564 | + |
| 565 | +def _wrapped_config( |
| 566 | + otel_wrapper: OTelWrapper, |
| 567 | + config: Optional[GenerateContentConfigOrDict]) -> Optional[GenerateContentConfig]: |
| 568 | + if config is None: |
| 569 | + return None |
| 570 | + if isinstance(config, dict): |
| 571 | + config = GenerateContentConfig(config) |
| 572 | + if not config.tools: |
| 573 | + return config |
| 574 | + return _wrapped_config_with_tools(otel_wrapper, config) |
| 575 | + |
| 576 | + |
498 | 577 | def _create_instrumented_generate_content( |
499 | 578 | snapshot: _MethodsSnapshot, otel_wrapper: OTelWrapper |
500 | 579 | ): |
|
0 commit comments