33from __future__ import annotations
44
55import asyncio
6+ import base64
67import codecs
78from collections .abc import AsyncGenerator , AsyncIterator , Callable
8- from dataclasses import replace
9+ from dataclasses import dataclass , replace
910import mimetypes
1011from pathlib import Path
11- from typing import TYPE_CHECKING , Any , cast
12+ from typing import TYPE_CHECKING , Any , Literal , cast
1213
1314from google .genai import Client
1415from google .genai .errors import APIError , ClientError
2728 PartUnionDict ,
2829 SafetySetting ,
2930 Schema ,
31+ ThinkingConfig ,
3032 Tool ,
3133 ToolListUnion ,
3234)
@@ -201,6 +203,30 @@ def _create_google_tool_response_content(
201203 )
202204
203205
206+ @dataclass (slots = True )
207+ class PartDetails :
208+ """Additional data for a content part."""
209+
210+ part_type : Literal ["text" , "thought" , "function_call" ]
211+ """The part type for which this data is relevant for."""
212+
213+ index : int
214+ """Start position or number of the tool."""
215+
216+ length : int = 0
217+ """Length of the relevant data."""
218+
219+ thought_signature : str | None = None
220+ """Base64 encoded thought signature, if available."""
221+
222+
223+ @dataclass (slots = True )
224+ class ContentDetails :
225+ """Native data for AssistantContent."""
226+
227+ part_details : list [PartDetails ]
228+
229+
def _convert_content(
    content: (
        conversation.UserContent
        | conversation.AssistantContent
        | conversation.SystemContent
    ),
) -> Content:
    """Convert HA content to Google content.

    User and system messages map to a single text part. Assistant messages
    are re-split into the original Google parts (text, thoughts and function
    calls) using the PartDetails recorded in ``content.native``, restoring
    the base64-encoded thought signatures on the matching parts.
    """
    if content.role != "assistant":
        return Content(
            role=content.role,
            parts=[Part.from_text(text=content.content if content.content else "")],
        )

    # Handle the Assistant content with tool calls.
    assert type(content) is conversation.AssistantContent
    parts: list[Part] = []
    part_details: list[PartDetails] = (
        content.native.part_details
        if isinstance(content.native, ContentDetails)
        else []
    )
    details: PartDetails | None = None

    if content.content:
        # Split the concatenated text back into its original parts so the
        # thought signatures attach to the same parts the API produced.
        index = 0
        for details in part_details:
            if details.part_type == "text":
                if index < details.index:
                    # Text with no recorded details becomes its own part.
                    parts.append(
                        Part.from_text(text=content.content[index : details.index])
                    )
                    index = details.index
                parts.append(
                    Part.from_text(
                        text=content.content[index : index + details.length],
                    )
                )
                if details.thought_signature:
                    parts[-1].thought_signature = base64.b64decode(
                        details.thought_signature
                    )
                index += details.length
        if index < len(content.content):
            parts.append(Part.from_text(text=content.content[index:]))

    if content.thinking_content:
        # Same reconstruction for thought parts; every emitted part is
        # flagged as a thought.
        index = 0
        for details in part_details:
            if details.part_type == "thought":
                if index < details.index:
                    parts.append(
                        Part.from_text(
                            text=content.thinking_content[index : details.index]
                        )
                    )
                    parts[-1].thought = True
                    index = details.index
                parts.append(
                    Part.from_text(
                        text=content.thinking_content[index : index + details.length],
                    )
                )
                parts[-1].thought = True
                if details.thought_signature:
                    parts[-1].thought_signature = base64.b64decode(
                        details.thought_signature
                    )
                index += details.length
        if index < len(content.thinking_content):
            parts.append(Part.from_text(text=content.thinking_content[index:]))
            parts[-1].thought = True

    if content.tool_calls:
        for index, tool_call in enumerate(content.tool_calls):
            parts.append(
                Part.from_function_call(
                    name=tool_call.tool_name,
                    args=_escape_decode(tool_call.tool_args),
                )
            )
            # Function-call details are matched by tool-call position, not
            # by a text offset.
            if details := next(
                (
                    d
                    for d in part_details
                    if d.part_type == "function_call" and d.index == index
                ),
                None,
            ):
                if details.thought_signature:
                    parts[-1].thought_signature = base64.b64decode(
                        details.thought_signature
                    )

    return Content(role="model", parts=parts)
240325
@@ -243,14 +328,20 @@ async def _transform_stream(
243328 result : AsyncIterator [GenerateContentResponse ],
244329) -> AsyncGenerator [conversation .AssistantContentDeltaDict ]:
245330 new_message = True
331+ part_details : list [PartDetails ] = []
246332 try :
247333 async for response in result :
248334 LOGGER .debug ("Received response chunk: %s" , response )
249- chunk : conversation .AssistantContentDeltaDict = {}
250335
251336 if new_message :
252- chunk ["role" ] = "assistant"
337+ if part_details :
338+ yield {"native" : ContentDetails (part_details = part_details )}
339+ part_details = []
340+ yield {"role" : "assistant" }
253341 new_message = False
342+ content_index = 0
343+ thinking_content_index = 0
344+ tool_call_index = 0
254345
255346 # According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
256347 if response .prompt_feedback or not response .candidates :
@@ -284,23 +375,62 @@ async def _transform_stream(
284375 else []
285376 )
286377
287- content = "" .join ([part .text for part in response_parts if part .text ])
288- tool_calls = []
289378 for part in response_parts :
290- if not part .function_call :
291- continue
292- tool_call = part .function_call
293- tool_name = tool_call .name if tool_call .name else ""
294- tool_args = _escape_decode (tool_call .args )
295- tool_calls .append (
296- llm .ToolInput (tool_name = tool_name , tool_args = tool_args )
297- )
379+ chunk : conversation .AssistantContentDeltaDict = {}
380+
381+ if part .text :
382+ if part .thought :
383+ chunk ["thinking_content" ] = part .text
384+ if part .thought_signature :
385+ part_details .append (
386+ PartDetails (
387+ part_type = "thought" ,
388+ index = thinking_content_index ,
389+ length = len (part .text ),
390+ thought_signature = base64 .b64encode (
391+ part .thought_signature
392+ ).decode ("utf-8" ),
393+ )
394+ )
395+ thinking_content_index += len (part .text )
396+ else :
397+ chunk ["content" ] = part .text
398+ if part .thought_signature :
399+ part_details .append (
400+ PartDetails (
401+ part_type = "text" ,
402+ index = content_index ,
403+ length = len (part .text ),
404+ thought_signature = base64 .b64encode (
405+ part .thought_signature
406+ ).decode ("utf-8" ),
407+ )
408+ )
409+ content_index += len (part .text )
410+
411+ if part .function_call :
412+ tool_call = part .function_call
413+ tool_name = tool_call .name if tool_call .name else ""
414+ tool_args = _escape_decode (tool_call .args )
415+ chunk ["tool_calls" ] = [
416+ llm .ToolInput (tool_name = tool_name , tool_args = tool_args )
417+ ]
418+ if part .thought_signature :
419+ part_details .append (
420+ PartDetails (
421+ part_type = "function_call" ,
422+ index = tool_call_index ,
423+ thought_signature = base64 .b64encode (
424+ part .thought_signature
425+ ).decode ("utf-8" ),
426+ )
427+ )
428+
429+ yield chunk
298430
299- if tool_calls :
300- chunk [ "tool_calls" ] = tool_calls
431+ if part_details :
432+ yield { "native" : ContentDetails ( part_details = part_details )}
301433
302- chunk ["content" ] = content
303- yield chunk
304434 except (
305435 APIError ,
306436 ValueError ,
@@ -522,6 +652,7 @@ def create_generate_content_config(self) -> GenerateContentConfig:
522652 ),
523653 ),
524654 ],
655+ thinking_config = ThinkingConfig (include_thoughts = True ),
525656 )
526657
527658
0 commit comments