22
33from collections .abc import AsyncIterator
44from dataclasses import dataclass , field
5+ from typing import Any
56
67from openai import AsyncStream
78from openai .types .chat import ChatCompletionChunk
@@ -65,6 +66,8 @@ class StreamingState:
6566 # Store accumulated thinking text and signature for Anthropic compatibility
6667 thinking_text : str = ""
6768 thinking_signature : str | None = None
69+ # Store thought signatures for Gemini function calls (indexed by tool call index)
70+ function_call_thought_signatures : dict [int , str ] = field (default_factory = dict )
6871
6972
7073class SequenceNumber :
@@ -359,6 +362,17 @@ async def handle_stream(
359362 if tc_delta .id :
360363 state .function_calls [tc_delta .index ].call_id = tc_delta .id
361364
365+ # Capture thought_signature from Gemini (provider_specific_fields)
366+ if (
367+ hasattr (tc_delta , "provider_specific_fields" )
368+ and tc_delta .provider_specific_fields
369+ ):
370+ provider_fields = tc_delta .provider_specific_fields
371+ if isinstance (provider_fields , dict ):
372+ thought_sig = provider_fields .get ("thought_signature" )
373+ if thought_sig :
374+ state .function_call_thought_signatures [tc_delta .index ] = thought_sig
375+
362376 function_call = state .function_calls [tc_delta .index ]
363377
364378 # Start streaming as soon as we have function name and call_id
@@ -483,14 +497,26 @@ async def handle_stream(
483497 if state .function_call_streaming .get (index , False ):
484498 # Function call was streamed, just send the completion event
485499 output_index = state .function_call_output_idx [index ]
500+
501+ # Build function call kwargs with thought_signature if available
502+ func_call_kwargs : dict [str , Any ] = {
503+ "id" : FAKE_RESPONSES_ID ,
504+ "call_id" : function_call .call_id ,
505+ "arguments" : function_call .arguments ,
506+ "name" : function_call .name ,
507+ "type" : "function_call" ,
508+ }
509+
510+ # Add thought_signature from Gemini if present
511+ if index in state .function_call_thought_signatures :
512+ func_call_kwargs ["provider_specific_fields" ] = {
513+ "google" : {
514+ "thought_signature" : state .function_call_thought_signatures [index ]
515+ }
516+ }
517+
486518 yield ResponseOutputItemDoneEvent (
487- item = ResponseFunctionToolCall (
488- id = FAKE_RESPONSES_ID ,
489- call_id = function_call .call_id ,
490- arguments = function_call .arguments ,
491- name = function_call .name ,
492- type = "function_call" ,
493- ),
519+ item = ResponseFunctionToolCall (** func_call_kwargs ),
494520 output_index = output_index ,
495521 type = "response.output_item.done" ,
496522 sequence_number = sequence_number .get_and_increment (),
@@ -511,15 +537,26 @@ async def handle_stream(
511537 1 for streaming in state .function_call_streaming .values () if streaming
512538 )
513539
540+ # Build function call kwargs with thought_signature if available
541+ fallback_func_call_kwargs : dict [str , Any ] = {
542+ "id" : FAKE_RESPONSES_ID ,
543+ "call_id" : function_call .call_id ,
544+ "arguments" : function_call .arguments ,
545+ "name" : function_call .name ,
546+ "type" : "function_call" ,
547+ }
548+
549+ # Add thought_signature from Gemini if present
550+ if index in state .function_call_thought_signatures :
551+ fallback_func_call_kwargs ["provider_specific_fields" ] = {
552+ "google" : {
553+ "thought_signature" : state .function_call_thought_signatures [index ]
554+ }
555+ }
556+
514557 # Send all events at once (backward compatibility)
515558 yield ResponseOutputItemAddedEvent (
516- item = ResponseFunctionToolCall (
517- id = FAKE_RESPONSES_ID ,
518- call_id = function_call .call_id ,
519- arguments = function_call .arguments ,
520- name = function_call .name ,
521- type = "function_call" ,
522- ),
559+ item = ResponseFunctionToolCall (** fallback_func_call_kwargs ),
523560 output_index = fallback_starting_index ,
524561 type = "response.output_item.added" ,
525562 sequence_number = sequence_number .get_and_increment (),
@@ -532,13 +569,7 @@ async def handle_stream(
532569 sequence_number = sequence_number .get_and_increment (),
533570 )
534571 yield ResponseOutputItemDoneEvent (
535- item = ResponseFunctionToolCall (
536- id = FAKE_RESPONSES_ID ,
537- call_id = function_call .call_id ,
538- arguments = function_call .arguments ,
539- name = function_call .name ,
540- type = "function_call" ,
541- ),
572+ item = ResponseFunctionToolCall (** fallback_func_call_kwargs ),
542573 output_index = fallback_starting_index ,
543574 type = "response.output_item.done" ,
544575 sequence_number = sequence_number .get_and_increment (),
@@ -587,8 +618,24 @@ async def handle_stream(
587618 sequence_number = sequence_number .get_and_increment (),
588619 )
589620
590- for function_call in state .function_calls .values ():
591- outputs .append (function_call )
621+ for index , function_call in state .function_calls .items ():
622+ # Reconstruct function call with thought_signature if available
623+ if index in state .function_call_thought_signatures :
624+ func_call_with_signature = ResponseFunctionToolCall (
625+ id = function_call .id ,
626+ call_id = function_call .call_id ,
627+ arguments = function_call .arguments ,
628+ name = function_call .name ,
629+ type = "function_call" ,
630+ provider_specific_fields = { # type: ignore[call-arg]
631+ "google" : {
632+ "thought_signature" : state .function_call_thought_signatures [index ]
633+ }
634+ },
635+ )
636+ outputs .append (func_call_with_signature )
637+ else :
638+ outputs .append (function_call )
592639
593640 final_response = response .model_copy ()
594641 final_response .output = outputs
0 commit comments