Commit 9f3dae6

Add tools in default agent also in fallback pipeline (home-assistant#157441)
1 parent: ef36d7b

File tree

5 files changed: +185 / -159 lines changed

homeassistant/components/assist_pipeline/pipeline.py

Lines changed: 58 additions & 57 deletions
@@ -1123,63 +1123,6 @@ async def recognize_intent(
         )

         try:
-            user_input = conversation.ConversationInput(
-                text=intent_input,
-                context=self.context,
-                conversation_id=conversation_id,
-                device_id=self._device_id,
-                satellite_id=self._satellite_id,
-                language=input_language,
-                agent_id=self.intent_agent.id,
-                extra_system_prompt=conversation_extra_system_prompt,
-            )
-
-            agent_id = self.intent_agent.id
-            processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
-            all_targets_in_satellite_area = False
-            intent_response: intent.IntentResponse | None = None
-            if not processed_locally and not self._intent_agent_only:
-                # Sentence triggers override conversation agent
-                if (
-                    trigger_response_text
-                    := await conversation.async_handle_sentence_triggers(
-                        self.hass, user_input
-                    )
-                ) is not None:
-                    # Sentence trigger matched
-                    agent_id = "sentence_trigger"
-                    processed_locally = True
-                    intent_response = intent.IntentResponse(
-                        self.pipeline.conversation_language
-                    )
-                    intent_response.async_set_speech(trigger_response_text)
-
-                intent_filter: Callable[[RecognizeResult], bool] | None = None
-                # If the LLM has API access, we filter out some sentences that are
-                # interfering with LLM operation.
-                if (
-                    intent_agent_state := self.hass.states.get(self.intent_agent.id)
-                ) and intent_agent_state.attributes.get(
-                    ATTR_SUPPORTED_FEATURES, 0
-                ) & conversation.ConversationEntityFeature.CONTROL:
-                    intent_filter = _async_local_fallback_intent_filter
-
-                # Try local intents
-                if (
-                    intent_response is None
-                    and self.pipeline.prefer_local_intents
-                    and (
-                        intent_response := await conversation.async_handle_intents(
-                            self.hass,
-                            user_input,
-                            intent_filter=intent_filter,
-                        )
-                    )
-                ):
-                    # Local intent matched
-                    agent_id = conversation.HOME_ASSISTANT_AGENT
-                    processed_locally = True
-
             if self.tts_stream and self.tts_stream.supports_streaming_input:
                 tts_input_stream: asyncio.Queue[str | None] | None = asyncio.Queue()
             else:
@@ -1265,6 +1208,17 @@ async def tts_input_stream_generator() -> AsyncGenerator[str]:
                 assert self.tts_stream is not None
                 self.tts_stream.async_set_message_stream(tts_input_stream_generator())

+            user_input = conversation.ConversationInput(
+                text=intent_input,
+                context=self.context,
+                conversation_id=conversation_id,
+                device_id=self._device_id,
+                satellite_id=self._satellite_id,
+                language=input_language,
+                agent_id=self.intent_agent.id,
+                extra_system_prompt=conversation_extra_system_prompt,
+            )
+
            with (
                chat_session.async_get_chat_session(
                    self.hass, user_input.conversation_id
@@ -1276,6 +1230,53 @@ async def tts_input_stream_generator() -> AsyncGenerator[str]:
                    chat_log_delta_listener=chat_log_delta_listener,
                ) as chat_log,
            ):
+                agent_id = self.intent_agent.id
+                processed_locally = agent_id == conversation.HOME_ASSISTANT_AGENT
+                all_targets_in_satellite_area = False
+                intent_response: intent.IntentResponse | None = None
+                if not processed_locally and not self._intent_agent_only:
+                    # Sentence triggers override conversation agent
+                    if (
+                        trigger_response_text
+                        := await conversation.async_handle_sentence_triggers(
+                            self.hass, user_input, chat_log
+                        )
+                    ) is not None:
+                        # Sentence trigger matched
+                        agent_id = "sentence_trigger"
+                        processed_locally = True
+                        intent_response = intent.IntentResponse(
+                            self.pipeline.conversation_language
+                        )
+                        intent_response.async_set_speech(trigger_response_text)
+
+                    intent_filter: Callable[[RecognizeResult], bool] | None = None
+                    # If the LLM has API access, we filter out some sentences that are
+                    # interfering with LLM operation.
+                    if (
+                        intent_agent_state := self.hass.states.get(self.intent_agent.id)
+                    ) and intent_agent_state.attributes.get(
+                        ATTR_SUPPORTED_FEATURES, 0
+                    ) & conversation.ConversationEntityFeature.CONTROL:
+                        intent_filter = _async_local_fallback_intent_filter
+
+                    # Try local intents
+                    if (
+                        intent_response is None
+                        and self.pipeline.prefer_local_intents
+                        and (
+                            intent_response := await conversation.async_handle_intents(
+                                self.hass,
+                                user_input,
+                                chat_log,
+                                intent_filter=intent_filter,
+                            )
+                        )
+                    ):
+                        # Local intent matched
+                        agent_id = conversation.HOME_ASSISTANT_AGENT
+                        processed_locally = True
+
                # It was already handled, create response and add to chat history
                if intent_response is not None:
                    speech: str = intent_response.speech.get("plain", {}).get(
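In effect, the pipeline change does not alter the fallback logic itself: the ConversationInput is now built just before the `with chat_session ...` block, and the sentence-trigger / local-intent handling moves inside that block, where `chat_log` is in scope to pass down. Below is a minimal, self-contained sketch of that reordering; `FakeChatLog` and `fake_chat_log` are stand-ins for illustration only, not Home Assistant's `ChatLog` or `chat_session` APIs.

# Illustrative sketch only: models the control-flow change with stand-in types.
from collections.abc import Iterator
from contextlib import contextmanager
from dataclasses import dataclass, field


@dataclass
class FakeChatLog:
    """Stand-in for conversation.ChatLog: just collects entries."""

    content: list[str] = field(default_factory=list)


@contextmanager
def fake_chat_log(conversation_id: str) -> Iterator[FakeChatLog]:
    """Stand-in for the chat_session / chat_log context managers."""
    yield FakeChatLog()


def recognize_intent(intent_input: str, conversation_id: str) -> None:
    # Before this commit: sentence triggers and local intents were tried here,
    # outside the chat-log context, so nothing could be recorded in the chat log.
    with fake_chat_log(conversation_id) as chat_log:
        # After this commit: the same fallback handling runs here and receives
        # chat_log, so the default agent can append tool calls and results to it.
        chat_log.content.append(f"handled: {intent_input}")


recognize_intent("turn on the kitchen light", "example-conversation-id")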

homeassistant/components/conversation/__init__.py

Lines changed: 8 additions & 3 deletions
@@ -236,7 +236,9 @@ async def async_prepare_agent(


 async def async_handle_sentence_triggers(
-    hass: HomeAssistant, user_input: ConversationInput
+    hass: HomeAssistant,
+    user_input: ConversationInput,
+    chat_log: ChatLog,
 ) -> str | None:
     """Try to match input against sentence triggers and return response text.

@@ -245,12 +247,13 @@ async def async_handle_sentence_triggers(
     agent = get_agent_manager(hass).default_agent
     assert agent is not None

-    return await agent.async_handle_sentence_triggers(user_input)
+    return await agent.async_handle_sentence_triggers(user_input, chat_log)


 async def async_handle_intents(
     hass: HomeAssistant,
     user_input: ConversationInput,
+    chat_log: ChatLog,
     *,
     intent_filter: Callable[[RecognizeResult], bool] | None = None,
 ) -> intent.IntentResponse | None:
@@ -261,7 +264,9 @@ async def async_handle_intents(
     agent = get_agent_manager(hass).default_agent
     assert agent is not None

-    return await agent.async_handle_intents(user_input, intent_filter=intent_filter)
+    return await agent.async_handle_intents(
+        user_input, chat_log, intent_filter=intent_filter
+    )


 async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
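Both module-level helpers now take the chat log as a required positional argument, while `intent_filter` stays keyword-only. A minimal, runnable sketch of the new calling convention follows; the functions are stubs that only mirror the signatures, and `hass`, `user_input`, and `chat_log` are placeholders rather than real Home Assistant objects.

# Illustrative stubs only: mirror the updated helper signatures and show the caller side.
import asyncio
from collections.abc import Callable
from typing import Any


async def async_handle_sentence_triggers(hass: Any, user_input: Any, chat_log: Any) -> str | None:
    # Stub: returns None when no sentence trigger matches.
    return None


async def async_handle_intents(
    hass: Any,
    user_input: Any,
    chat_log: Any,
    *,
    intent_filter: Callable[[Any], bool] | None = None,
) -> Any | None:
    # Stub: returns None when no local intent matches.
    return None


async def fallback(hass: Any, user_input: Any, chat_log: Any) -> None:
    # Caller side: the chat log is now threaded through both helpers.
    if (text := await async_handle_sentence_triggers(hass, user_input, chat_log)) is not None:
        print("sentence trigger response:", text)
    elif (response := await async_handle_intents(hass, user_input, chat_log)) is not None:
        print("local intent response:", response)
    else:
        print("no local match; fall back to the conversation agent")


asyncio.run(fallback(hass=None, user_input=None, chat_log=None))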

homeassistant/components/conversation/default_agent.py

Lines changed: 61 additions & 56 deletions
@@ -431,26 +431,14 @@ async def _async_handle_message(
     ) -> ConversationResult:
         """Handle a message."""
         response: intent.IntentResponse | None = None
-        tool_input: llm.ToolInput | None = None
-        tool_result: dict[str, Any] = {}

         # Check if a trigger matched
         if trigger_result := await self.async_recognize_sentence_trigger(user_input):
             # Process callbacks and get response
             response_text = await self._handle_trigger_result(
-                trigger_result, user_input
+                trigger_result, user_input, chat_log
             )

-            # Create tool result
-            tool_input = llm.ToolInput(
-                tool_name="trigger_sentence",
-                tool_args={},
-                external=True,
-            )
-            tool_result = {
-                "response": response_text,
-            }
-
             # Convert to conversation result
             response = intent.IntentResponse(
                 language=user_input.language or self.hass.config.language
@@ -462,40 +450,7 @@
             intent_result = await self.async_recognize_intent(user_input)

             response = await self._async_process_intent_result(
-                intent_result, user_input
-            )
-
-            if response.response_type != intent.IntentResponseType.ERROR:
-                assert intent_result is not None
-                assert intent_result.intent is not None
-                # Create external tool call for the intent
-                tool_input = llm.ToolInput(
-                    tool_name=intent_result.intent.name,
-                    tool_args={
-                        entity.name: entity.value or entity.text
-                        for entity in intent_result.entities_list
-                    },
-                    external=True,
-                )
-                # Create tool result from intent response
-                tool_result = llm.IntentResponseDict(response)
-
-        # Add tool call and result to chat log if we have one
-        if tool_input is not None:
-            chat_log.async_add_assistant_content_without_tools(
-                AssistantContent(
-                    agent_id=user_input.agent_id,
-                    content=None,
-                    tool_calls=[tool_input],
-                )
-            )
-            chat_log.async_add_assistant_content_without_tools(
-                ToolResultContent(
-                    agent_id=user_input.agent_id,
-                    tool_call_id=tool_input.id,
-                    tool_name=tool_input.tool_name,
-                    tool_result=tool_result,
-                )
+                intent_result, user_input, chat_log
             )

         speech: str = response.speech.get("plain", {}).get("speech", "")
@@ -514,6 +469,7 @@ async def _async_process_intent_result(
         self,
         result: RecognizeResult | None,
         user_input: ConversationInput,
+        chat_log: ChatLog,
     ) -> intent.IntentResponse:
         """Process user input with intents."""
         language = user_input.language or self.hass.config.language
@@ -576,12 +532,21 @@
             ConversationTraceEventType.TOOL_CALL,
             {
                 "intent_name": result.intent.name,
-                "slots": {
-                    entity.name: entity.value or entity.text
-                    for entity in result.entities_list
-                },
+                "slots": {entity.name: entity.value for entity in result.entities_list},
             },
         )
+        tool_input = llm.ToolInput(
+            tool_name=result.intent.name,
+            tool_args={entity.name: entity.value for entity in result.entities_list},
+            external=True,
+        )
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id=user_input.agent_id,
+                content=None,
+                tool_calls=[tool_input],
+            )
+        )

         try:
             intent_response = await intent.async_handle(
@@ -644,6 +609,16 @@
             )
             intent_response.async_set_speech(speech)

+        tool_result = llm.IntentResponseDict(intent_response)
+        chat_log.async_add_assistant_content_without_tools(
+            ToolResultContent(
+                agent_id=user_input.agent_id,
+                tool_call_id=tool_input.id,
+                tool_name=tool_input.tool_name,
+                tool_result=tool_result,
+            )
+        )
+
         return intent_response

     def _recognize(
@@ -1570,16 +1545,31 @@ async def async_recognize_sentence_trigger(
         )

     async def _handle_trigger_result(
-        self, result: SentenceTriggerResult, user_input: ConversationInput
+        self,
+        result: SentenceTriggerResult,
+        user_input: ConversationInput,
+        chat_log: ChatLog,
     ) -> str:
         """Run sentence trigger callbacks and return response text."""
-
         # Gather callback responses in parallel
         trigger_callbacks = [
             self._triggers_details[trigger_id].callback(user_input, trigger_result)
             for trigger_id, trigger_result in result.matched_triggers.items()
         ]

+        tool_input = llm.ToolInput(
+            tool_name="trigger_sentence",
+            tool_args={},
+            external=True,
+        )
+        chat_log.async_add_assistant_content_without_tools(
+            AssistantContent(
+                agent_id=user_input.agent_id,
+                content=None,
+                tool_calls=[tool_input],
+            )
+        )
+
         # Use first non-empty result as response.
         #
         # There may be multiple copies of a trigger running when editing in
@@ -1608,23 +1598,38 @@
                     f"component.{DOMAIN}.conversation.agent.done", "Done"
                 )

+        tool_result: dict[str, Any] = {"response": response_text}
+        chat_log.async_add_assistant_content_without_tools(
+            ToolResultContent(
+                agent_id=user_input.agent_id,
+                tool_call_id=tool_input.id,
+                tool_name=tool_input.tool_name,
+                tool_result=tool_result,
+            )
+        )
+
         return response_text

     async def async_handle_sentence_triggers(
-        self, user_input: ConversationInput
+        self,
+        user_input: ConversationInput,
+        chat_log: ChatLog,
     ) -> str | None:
         """Try to input sentence against sentence triggers and return response text.

         Returns None if no match occurred.
         """
         if trigger_result := await self.async_recognize_sentence_trigger(user_input):
-            return await self._handle_trigger_result(trigger_result, user_input)
+            return await self._handle_trigger_result(
+                trigger_result, user_input, chat_log
+            )

         return None

     async def async_handle_intents(
         self,
         user_input: ConversationInput,
+        chat_log: ChatLog,
         *,
         intent_filter: Callable[[RecognizeResult], bool] | None = None,
     ) -> intent.IntentResponse | None:
@@ -1640,7 +1645,7 @@ async def async_handle_intents(
             # No error message on failed match
             return None

-        response = await self._async_process_intent_result(result, user_input)
+        response = await self._async_process_intent_result(result, user_input, chat_log)
         if (
             response.response_type == intent.IntentResponseType.ERROR
             and response.error_code
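The net effect in the default agent is a paired record in the chat log: an "external" tool call is appended before the sentence trigger or intent is handled, and a matching result, keyed by the same tool_call_id, is appended afterwards, presumably so that whatever later reads the chat log can see what was already handled locally. The sketch below models only that pairing; `ToolCall` and `ChatLogModel` are stand-ins rather than the real `llm.ToolInput`, `AssistantContent`, `ToolResultContent`, or `ChatLog` classes, and the intent name and speech text are made up for the example.

# Illustrative model of the call/result pairing with stand-in classes.
from dataclasses import dataclass, field
from typing import Any
from uuid import uuid4


@dataclass
class ToolCall:
    # Imitation of an "external" tool call with a generated id.
    tool_name: str
    tool_args: dict[str, Any]
    id: str = field(default_factory=lambda: uuid4().hex)


@dataclass
class ChatLogModel:
    # Imitation of a chat log that collects assistant tool calls and tool results.
    content: list[dict[str, Any]] = field(default_factory=list)

    def add_assistant_tool_call(self, call: ToolCall) -> None:
        self.content.append({"role": "assistant", "content": None, "tool_calls": [call]})

    def add_tool_result(self, call: ToolCall, result: Any) -> None:
        self.content.append(
            {
                "role": "tool_result",
                "tool_call_id": call.id,
                "tool_name": call.tool_name,
                "tool_result": result,
            }
        )


chat_log = ChatLogModel()
call = ToolCall(tool_name="HassTurnOn", tool_args={"name": "kitchen light"})
chat_log.add_assistant_tool_call(call)  # appended before the intent is handled
chat_log.add_tool_result(call, {"speech": "Turned on the kitchen light"})  # appended after
print(f"{len(chat_log.content)} chat log entries")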
