Skip to content

Commit ee506e6

Browse files
Shulyakajoostlek
andauthored
Implement thinking content for Gemini (home-assistant#150347)
Co-authored-by: Joost Lekkerkerker <[email protected]>
1 parent 8003a49 commit ee506e6

File tree

4 files changed

+275
-31
lines changed

4 files changed

+275
-31
lines changed

homeassistant/components/google_generative_ai_conversation/entity.py

Lines changed: 161 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
from __future__ import annotations
44

55
import asyncio
6+
import base64
67
import codecs
78
from collections.abc import AsyncGenerator, AsyncIterator, Callable
8-
from dataclasses import replace
9+
from dataclasses import dataclass, replace
910
import mimetypes
1011
from pathlib import Path
11-
from typing import TYPE_CHECKING, Any, cast
12+
from typing import TYPE_CHECKING, Any, Literal, cast
1213

1314
from google.genai import Client
1415
from google.genai.errors import APIError, ClientError
@@ -27,6 +28,7 @@
2728
PartUnionDict,
2829
SafetySetting,
2930
Schema,
31+
ThinkingConfig,
3032
Tool,
3133
ToolListUnion,
3234
)
@@ -201,6 +203,30 @@ def _create_google_tool_response_content(
201203
)
202204

203205

206+
@dataclass(slots=True)
207+
class PartDetails:
208+
"""Additional data for a content part."""
209+
210+
part_type: Literal["text", "thought", "function_call"]
211+
"""The part type for which this data is relevant for."""
212+
213+
index: int
214+
"""Start position or number of the tool."""
215+
216+
length: int = 0
217+
"""Length of the relevant data."""
218+
219+
thought_signature: str | None = None
220+
"""Base64 encoded thought signature, if available."""
221+
222+
223+
@dataclass(slots=True)
224+
class ContentDetails:
225+
"""Native data for AssistantContent."""
226+
227+
part_details: list[PartDetails]
228+
229+
204230
def _convert_content(
205231
content: (
206232
conversation.UserContent
@@ -209,32 +235,91 @@ def _convert_content(
209235
),
210236
) -> Content:
211237
"""Convert HA content to Google content."""
212-
if content.role != "assistant" or not content.tool_calls:
213-
role = "model" if content.role == "assistant" else content.role
238+
if content.role != "assistant":
214239
return Content(
215-
role=role,
216-
parts=[
217-
Part.from_text(text=content.content if content.content else ""),
218-
],
240+
role=content.role,
241+
parts=[Part.from_text(text=content.content if content.content else "")],
219242
)
220243

221244
# Handle the Assistant content with tool calls.
222245
assert type(content) is conversation.AssistantContent
223246
parts: list[Part] = []
247+
part_details: list[PartDetails] = (
248+
content.native.part_details
249+
if isinstance(content.native, ContentDetails)
250+
else []
251+
)
252+
details: PartDetails | None = None
224253

225254
if content.content:
226-
parts.append(Part.from_text(text=content.content))
255+
index = 0
256+
for details in part_details:
257+
if details.part_type == "text":
258+
if index < details.index:
259+
parts.append(
260+
Part.from_text(text=content.content[index : details.index])
261+
)
262+
index = details.index
263+
parts.append(
264+
Part.from_text(
265+
text=content.content[index : index + details.length],
266+
)
267+
)
268+
if details.thought_signature:
269+
parts[-1].thought_signature = base64.b64decode(
270+
details.thought_signature
271+
)
272+
index += details.length
273+
if index < len(content.content):
274+
parts.append(Part.from_text(text=content.content[index:]))
275+
276+
if content.thinking_content:
277+
index = 0
278+
for details in part_details:
279+
if details.part_type == "thought":
280+
if index < details.index:
281+
parts.append(
282+
Part.from_text(
283+
text=content.thinking_content[index : details.index]
284+
)
285+
)
286+
parts[-1].thought = True
287+
index = details.index
288+
parts.append(
289+
Part.from_text(
290+
text=content.thinking_content[index : index + details.length],
291+
)
292+
)
293+
parts[-1].thought = True
294+
if details.thought_signature:
295+
parts[-1].thought_signature = base64.b64decode(
296+
details.thought_signature
297+
)
298+
index += details.length
299+
if index < len(content.thinking_content):
300+
parts.append(Part.from_text(text=content.thinking_content[index:]))
301+
parts[-1].thought = True
227302

228303
if content.tool_calls:
229-
parts.extend(
230-
[
304+
for index, tool_call in enumerate(content.tool_calls):
305+
parts.append(
231306
Part.from_function_call(
232307
name=tool_call.tool_name,
233308
args=_escape_decode(tool_call.tool_args),
234309
)
235-
for tool_call in content.tool_calls
236-
]
237-
)
310+
)
311+
if details := next(
312+
(
313+
d
314+
for d in part_details
315+
if d.part_type == "function_call" and d.index == index
316+
),
317+
None,
318+
):
319+
if details.thought_signature:
320+
parts[-1].thought_signature = base64.b64decode(
321+
details.thought_signature
322+
)
238323

239324
return Content(role="model", parts=parts)
240325

@@ -243,14 +328,20 @@ async def _transform_stream(
243328
result: AsyncIterator[GenerateContentResponse],
244329
) -> AsyncGenerator[conversation.AssistantContentDeltaDict]:
245330
new_message = True
331+
part_details: list[PartDetails] = []
246332
try:
247333
async for response in result:
248334
LOGGER.debug("Received response chunk: %s", response)
249-
chunk: conversation.AssistantContentDeltaDict = {}
250335

251336
if new_message:
252-
chunk["role"] = "assistant"
337+
if part_details:
338+
yield {"native": ContentDetails(part_details=part_details)}
339+
part_details = []
340+
yield {"role": "assistant"}
253341
new_message = False
342+
content_index = 0
343+
thinking_content_index = 0
344+
tool_call_index = 0
254345

255346
# According to the API docs, this would mean no candidate is returned, so we can safely throw an error here.
256347
if response.prompt_feedback or not response.candidates:
@@ -284,23 +375,62 @@ async def _transform_stream(
284375
else []
285376
)
286377

287-
content = "".join([part.text for part in response_parts if part.text])
288-
tool_calls = []
289378
for part in response_parts:
290-
if not part.function_call:
291-
continue
292-
tool_call = part.function_call
293-
tool_name = tool_call.name if tool_call.name else ""
294-
tool_args = _escape_decode(tool_call.args)
295-
tool_calls.append(
296-
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
297-
)
379+
chunk: conversation.AssistantContentDeltaDict = {}
380+
381+
if part.text:
382+
if part.thought:
383+
chunk["thinking_content"] = part.text
384+
if part.thought_signature:
385+
part_details.append(
386+
PartDetails(
387+
part_type="thought",
388+
index=thinking_content_index,
389+
length=len(part.text),
390+
thought_signature=base64.b64encode(
391+
part.thought_signature
392+
).decode("utf-8"),
393+
)
394+
)
395+
thinking_content_index += len(part.text)
396+
else:
397+
chunk["content"] = part.text
398+
if part.thought_signature:
399+
part_details.append(
400+
PartDetails(
401+
part_type="text",
402+
index=content_index,
403+
length=len(part.text),
404+
thought_signature=base64.b64encode(
405+
part.thought_signature
406+
).decode("utf-8"),
407+
)
408+
)
409+
content_index += len(part.text)
410+
411+
if part.function_call:
412+
tool_call = part.function_call
413+
tool_name = tool_call.name if tool_call.name else ""
414+
tool_args = _escape_decode(tool_call.args)
415+
chunk["tool_calls"] = [
416+
llm.ToolInput(tool_name=tool_name, tool_args=tool_args)
417+
]
418+
if part.thought_signature:
419+
part_details.append(
420+
PartDetails(
421+
part_type="function_call",
422+
index=tool_call_index,
423+
thought_signature=base64.b64encode(
424+
part.thought_signature
425+
).decode("utf-8"),
426+
)
427+
)
428+
429+
yield chunk
298430

299-
if tool_calls:
300-
chunk["tool_calls"] = tool_calls
431+
if part_details:
432+
yield {"native": ContentDetails(part_details=part_details)}
301433

302-
chunk["content"] = content
303-
yield chunk
304434
except (
305435
APIError,
306436
ValueError,
@@ -522,6 +652,7 @@ def create_generate_content_config(self) -> GenerateContentConfig:
522652
),
523653
),
524654
],
655+
thinking_config=ThinkingConfig(include_thoughts=True),
525656
)
526657

527658

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
# serializer version: 1
2+
# name: test_function_call
3+
list([
4+
Content(
5+
parts=[
6+
Part(
7+
text='Please call the test function'
8+
),
9+
],
10+
role='user'
11+
),
12+
Content(
13+
parts=[
14+
Part(
15+
text='Hi there!',
16+
thought_signature=b'_thought_signature_2'
17+
),
18+
Part(
19+
text='The user asked me to call a function',
20+
thought=True,
21+
thought_signature=b'_thought_signature_1'
22+
),
23+
Part(
24+
function_call=FunctionCall(
25+
args={
26+
'param1': [
27+
'test_value',
28+
"param1's value",
29+
],
30+
'param2': 2.7
31+
},
32+
name='test_tool'
33+
),
34+
thought_signature=b'_thought_signature_3'
35+
),
36+
],
37+
role='model'
38+
),
39+
Content(
40+
parts=[
41+
Part(
42+
function_response=FunctionResponse(
43+
name='test_tool',
44+
response={
45+
'result': 'Test response'
46+
}
47+
)
48+
),
49+
],
50+
role='user'
51+
),
52+
Content(
53+
parts=[
54+
Part(
55+
text="I've called the ",
56+
thought_signature=b'_thought_signature_4'
57+
),
58+
Part(
59+
text='test function with the provided parameters.',
60+
thought_signature=b'_thought_signature_5'
61+
),
62+
],
63+
role='model'
64+
),
65+
])
66+
# ---

0 commit comments

Comments
 (0)