Skip to content

Commit 51c35eb

Browse files
authored
Move default conversation agent to manager (home-assistant#152479)
1 parent 24a86d0 commit 51c35eb

File tree

12 files changed

+73
-83
lines changed

12 files changed

+73
-83
lines changed

homeassistant/components/conversation/__init__.py

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -50,14 +50,13 @@
5050
ATTR_LANGUAGE,
5151
ATTR_TEXT,
5252
DATA_COMPONENT,
53-
DATA_DEFAULT_ENTITY,
5453
DOMAIN,
5554
HOME_ASSISTANT_AGENT,
5655
SERVICE_PROCESS,
5756
SERVICE_RELOAD,
5857
ConversationEntityFeature,
5958
)
60-
from .default_agent import DefaultAgent, async_setup_default_agent
59+
from .default_agent import async_setup_default_agent
6160
from .entity import ConversationEntity
6261
from .http import async_setup as async_setup_conversation_http
6362
from .models import AbstractConversationAgent, ConversationInput, ConversationResult
@@ -142,7 +141,7 @@ def async_unset_agent(
142141
hass: HomeAssistant,
143142
config_entry: ConfigEntry,
144143
) -> None:
145-
"""Set the agent to handle the conversations."""
144+
"""Unset the agent to handle the conversations."""
146145
get_agent_manager(hass).async_unset_agent(config_entry.entry_id)
147146

148147

@@ -241,10 +240,10 @@ async def async_handle_sentence_triggers(
241240
242241
Returns None if no match occurred.
243242
"""
244-
default_agent = async_get_agent(hass)
245-
assert isinstance(default_agent, DefaultAgent)
243+
agent = get_agent_manager(hass).default_agent
244+
assert agent is not None
246245

247-
return await default_agent.async_handle_sentence_triggers(user_input)
246+
return await agent.async_handle_sentence_triggers(user_input)
248247

249248

250249
async def async_handle_intents(
@@ -257,12 +256,10 @@ async def async_handle_intents(
257256
258257
Returns None if no match occurred.
259258
"""
260-
default_agent = async_get_agent(hass)
261-
assert isinstance(default_agent, DefaultAgent)
259+
agent = get_agent_manager(hass).default_agent
260+
assert agent is not None
262261

263-
return await default_agent.async_handle_intents(
264-
user_input, intent_filter=intent_filter
265-
)
262+
return await agent.async_handle_intents(user_input, intent_filter=intent_filter)
266263

267264

268265
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@@ -298,9 +295,9 @@ async def handle_process(service: ServiceCall) -> ServiceResponse:
298295

299296
async def handle_reload(service: ServiceCall) -> None:
300297
"""Reload intents."""
301-
await hass.data[DATA_DEFAULT_ENTITY].async_reload(
302-
language=service.data.get(ATTR_LANGUAGE)
303-
)
298+
agent = get_agent_manager(hass).default_agent
299+
if agent is not None:
300+
await agent.async_reload(language=service.data.get(ATTR_LANGUAGE))
304301

305302
hass.services.async_register(
306303
DOMAIN,

homeassistant/components/conversation/agent_manager.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44

55
import dataclasses
66
import logging
7-
from typing import Any
7+
from typing import TYPE_CHECKING, Any
88

99
import voluptuous as vol
1010

1111
from homeassistant.core import Context, HomeAssistant, async_get_hass, callback
1212
from homeassistant.exceptions import HomeAssistantError
1313
from homeassistant.helpers import config_validation as cv, intent, singleton
1414

15-
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY, HOME_ASSISTANT_AGENT
15+
from .const import DATA_COMPONENT, HOME_ASSISTANT_AGENT
1616
from .entity import ConversationEntity
1717
from .models import (
1818
AbstractConversationAgent,
@@ -28,6 +28,9 @@
2828

2929
_LOGGER = logging.getLogger(__name__)
3030

31+
if TYPE_CHECKING:
32+
from .default_agent import DefaultAgent
33+
3134

3235
@singleton.singleton("conversation_agent")
3336
@callback
@@ -49,8 +52,10 @@ def async_get_agent(
4952
hass: HomeAssistant, agent_id: str | None = None
5053
) -> AbstractConversationAgent | ConversationEntity | None:
5154
"""Get specified agent."""
55+
manager = get_agent_manager(hass)
56+
5257
if agent_id is None or agent_id == HOME_ASSISTANT_AGENT:
53-
return hass.data[DATA_DEFAULT_ENTITY]
58+
return manager.default_agent
5459

5560
if "." in agent_id:
5661
return hass.data[DATA_COMPONENT].get_entity(agent_id)
@@ -134,6 +139,7 @@ def __init__(self, hass: HomeAssistant) -> None:
134139
"""Initialize the conversation agents."""
135140
self.hass = hass
136141
self._agents: dict[str, AbstractConversationAgent] = {}
142+
self.default_agent: DefaultAgent | None = None
137143

138144
@callback
139145
def async_get_agent(self, agent_id: str) -> AbstractConversationAgent | None:
@@ -182,3 +188,7 @@ def async_set_agent(self, agent_id: str, agent: AbstractConversationAgent) -> No
182188
def async_unset_agent(self, agent_id: str) -> None:
183189
"""Unset the agent."""
184190
self._agents.pop(agent_id, None)
191+
192+
async def async_setup_default_agent(self, agent: DefaultAgent) -> None:
193+
"""Set up the default agent."""
194+
self.default_agent = agent

homeassistant/components/conversation/const.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,9 @@
1010
if TYPE_CHECKING:
1111
from homeassistant.helpers.entity_component import EntityComponent
1212

13-
from .default_agent import DefaultAgent
1413
from .entity import ConversationEntity
1514

1615
DOMAIN = "conversation"
17-
DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
1816
HOME_ASSISTANT_AGENT = "conversation.home_assistant"
1917

2018
ATTR_TEXT = "text"
@@ -26,7 +24,6 @@
2624
SERVICE_RELOAD = "reload"
2725

2826
DATA_COMPONENT: HassKey[EntityComponent[ConversationEntity]] = HassKey(DOMAIN)
29-
DATA_DEFAULT_ENTITY: HassKey[DefaultAgent] = HassKey(f"{DOMAIN}_default_entity")
3027

3128

3229
class ConversationEntityFeature(IntFlag):

homeassistant/components/conversation/default_agent.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,9 @@
6868
from homeassistant.util import language as language_util
6969
from homeassistant.util.json import JsonObjectType, json_loads_object
7070

71+
from .agent_manager import get_agent_manager
7172
from .chat_log import AssistantContent, ChatLog
72-
from .const import (
73-
DATA_DEFAULT_ENTITY,
74-
DEFAULT_EXPOSED_ATTRIBUTES,
75-
DOMAIN,
76-
ConversationEntityFeature,
77-
)
73+
from .const import DOMAIN, ConversationEntityFeature
7874
from .entity import ConversationEntity
7975
from .models import ConversationInput, ConversationResult
8076
from .trace import ConversationTraceEventType, async_conversation_trace_append
@@ -83,6 +79,8 @@
8379
_DEFAULT_ERROR_TEXT = "Sorry, I couldn't understand that"
8480
_ENTITY_REGISTRY_UPDATE_FIELDS = ["aliases", "name", "original_name"]
8581

82+
_DEFAULT_EXPOSED_ATTRIBUTES = {"device_class"}
83+
8684
REGEX_TYPE = type(re.compile(""))
8785
TRIGGER_CALLBACK_TYPE = Callable[
8886
[ConversationInput, RecognizeResult], Awaitable[str | None]
@@ -209,9 +207,9 @@ async def async_setup_default_agent(
209207
config_intents: dict[str, Any],
210208
) -> None:
211209
"""Set up entity registry listener for the default agent."""
212-
entity = DefaultAgent(hass, config_intents)
213-
await entity_component.async_add_entities([entity])
214-
hass.data[DATA_DEFAULT_ENTITY] = entity
210+
agent = DefaultAgent(hass, config_intents)
211+
await entity_component.async_add_entities([agent])
212+
await get_agent_manager(hass).async_setup_default_agent(agent)
215213

216214
@core.callback
217215
def async_entity_state_listener(
@@ -846,7 +844,7 @@ def _get_entity_name_tuples(
846844
context = {"domain": state.domain}
847845
if state.attributes:
848846
# Include some attributes
849-
for attr in DEFAULT_EXPOSED_ATTRIBUTES:
847+
for attr in _DEFAULT_EXPOSED_ATTRIBUTES:
850848
if attr not in state.attributes:
851849
continue
852850
context[attr] = state.attributes[attr]

homeassistant/components/conversation/http.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
async_get_agent,
2626
get_agent_manager,
2727
)
28-
from .const import DATA_COMPONENT, DATA_DEFAULT_ENTITY
28+
from .const import DATA_COMPONENT
2929
from .default_agent import (
3030
METADATA_CUSTOM_FILE,
3131
METADATA_CUSTOM_SENTENCE,
@@ -169,7 +169,8 @@ async def websocket_list_sentences(
169169
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
170170
) -> None:
171171
"""List custom registered sentences."""
172-
agent = hass.data[DATA_DEFAULT_ENTITY]
172+
agent = get_agent_manager(hass).default_agent
173+
assert agent is not None
173174

174175
sentences = []
175176
for trigger_data in agent.trigger_sentences:
@@ -191,7 +192,8 @@ async def websocket_hass_agent_debug(
191192
hass: HomeAssistant, connection: websocket_api.ActiveConnection, msg: dict
192193
) -> None:
193194
"""Return intents that would be matched by the default agent for a list of sentences."""
194-
agent = hass.data[DATA_DEFAULT_ENTITY]
195+
agent = get_agent_manager(hass).default_agent
196+
assert agent is not None
195197

196198
# Return results for each sentence in the same order as the input.
197199
result_dicts: list[dict[str, Any] | None] = []

homeassistant/components/conversation/trigger.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,8 @@
2020
from homeassistant.helpers.trigger import TriggerActionType, TriggerInfo
2121
from homeassistant.helpers.typing import UNDEFINED, ConfigType
2222

23-
from .const import DATA_DEFAULT_ENTITY, DOMAIN
23+
from .agent_manager import get_agent_manager
24+
from .const import DOMAIN
2425
from .models import ConversationInput
2526

2627

@@ -123,4 +124,6 @@ async def call_action(
123124
# two trigger copies for who will provide a response.
124125
return None
125126

126-
return hass.data[DATA_DEFAULT_ENTITY].register_trigger(sentences, call_action)
127+
agent = get_agent_manager(hass).default_agent
128+
assert agent is not None
129+
return agent.register_trigger(sentences, call_action)

tests/components/assist_pipeline/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ async def async_setup_entry_wake_word_platform(
298298
assert await async_setup_component(hass, "conversation", {"conversation": {}})
299299

300300
# Disable fuzzy matching by default for tests
301-
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
301+
agent = conversation.async_get_agent(hass)
302302
agent.fuzzy_matching = False
303303

304304
config_entry = MockConfigEntry(domain="test")

tests/components/conversation/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import pytest
77

88
from homeassistant.components import conversation
9+
from homeassistant.components.conversation import async_get_agent, default_agent
910
from homeassistant.components.shopping_list import intent as sl_intent
1011
from homeassistant.const import MATCH_ALL
1112
from homeassistant.core import Context, HomeAssistant
@@ -77,5 +78,6 @@ async def init_components(hass: HomeAssistant):
7778
assert await async_setup_component(hass, "conversation", {conversation.DOMAIN: {}})
7879

7980
# Disable fuzzy matching by default for tests
80-
agent = hass.data[conversation.DATA_DEFAULT_ENTITY]
81+
agent = async_get_agent(hass)
82+
assert isinstance(agent, default_agent.DefaultAgent)
8183
agent.fuzzy_matching = False

tests/components/conversation/test_default_agent.py

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,7 @@
1212
import yaml
1313

1414
from homeassistant.components import conversation, cover, media_player, weather
15-
from homeassistant.components.conversation import default_agent
16-
from homeassistant.components.conversation.const import DATA_DEFAULT_ENTITY
15+
from homeassistant.components.conversation import async_get_agent, default_agent
1716
from homeassistant.components.conversation.default_agent import METADATA_CUSTOM_SENTENCE
1817
from homeassistant.components.conversation.models import ConversationInput
1918
from homeassistant.components.cover import SERVICE_OPEN_COVER
@@ -87,7 +86,8 @@ async def init_components(hass: HomeAssistant) -> None:
8786
assert await async_setup_component(hass, "intent", {})
8887

8988
# Disable fuzzy matching by default for tests
90-
agent = hass.data[DATA_DEFAULT_ENTITY]
89+
agent = async_get_agent(hass)
90+
assert isinstance(agent, default_agent.DefaultAgent)
9191
agent.fuzzy_matching = False
9292

9393

@@ -215,7 +215,7 @@ async def test_exposed_areas(
215215
@pytest.mark.usefixtures("init_components")
216216
async def test_conversation_agent(hass: HomeAssistant) -> None:
217217
"""Test DefaultAgent."""
218-
agent = hass.data[DATA_DEFAULT_ENTITY]
218+
agent = async_get_agent(hass)
219219
with patch(
220220
"homeassistant.components.conversation.default_agent.get_languages",
221221
return_value=["dwarvish", "elvish", "entish"],
@@ -415,8 +415,7 @@ async def test_trigger_sentences(hass: HomeAssistant) -> None:
415415
trigger_sentences = ["It's party time", "It is time to party"]
416416
trigger_response = "Cowabunga!"
417417

418-
agent = hass.data[DATA_DEFAULT_ENTITY]
419-
assert isinstance(agent, default_agent.DefaultAgent)
418+
agent = async_get_agent(hass)
420419

421420
callback = AsyncMock(return_value=trigger_response)
422421
unregister = agent.register_trigger(trigger_sentences, callback)
@@ -462,8 +461,7 @@ async def test_trigger_sentence_response_translation(
462461
"""Test translation of default response 'done'."""
463462
hass.config.language = language
464463

465-
agent = hass.data[DATA_DEFAULT_ENTITY]
466-
assert isinstance(agent, default_agent.DefaultAgent)
464+
agent = async_get_agent(hass)
467465

468466
translations = {
469467
"en": {"component.conversation.conversation.agent.done": "English done"},
@@ -2525,8 +2523,7 @@ async def test_non_default_response(hass: HomeAssistant, init_components) -> Non
25252523
hass.states.async_set("cover.front_door", "closed")
25262524
calls = async_mock_service(hass, "cover", SERVICE_OPEN_COVER)
25272525

2528-
agent = hass.data[DATA_DEFAULT_ENTITY]
2529-
assert isinstance(agent, default_agent.DefaultAgent)
2526+
agent = async_get_agent(hass)
25302527

25312528
result = await agent.async_process(
25322529
ConversationInput(
@@ -2872,8 +2869,7 @@ async def test_query_same_name_different_areas(
28722869
@pytest.mark.usefixtures("init_components")
28732870
async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
28742871
"""Test that intent recognition results are cached for exposed entities."""
2875-
agent = hass.data[DATA_DEFAULT_ENTITY]
2876-
assert isinstance(agent, default_agent.DefaultAgent)
2872+
agent = async_get_agent(hass)
28772873

28782874
entity_id = "light.test_light"
28792875
hass.states.async_set(entity_id, "off")
@@ -2912,8 +2908,7 @@ async def test_intent_cache_exposed(hass: HomeAssistant) -> None:
29122908
@pytest.mark.usefixtures("init_components")
29132909
async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
29142910
"""Test that intent recognition results are cached for all entities."""
2915-
agent = hass.data[DATA_DEFAULT_ENTITY]
2916-
assert isinstance(agent, default_agent.DefaultAgent)
2911+
agent = async_get_agent(hass)
29172912

29182913
entity_id = "light.test_light"
29192914
hass.states.async_set(entity_id, "off")
@@ -2952,8 +2947,7 @@ async def test_intent_cache_all_entities(hass: HomeAssistant) -> None:
29522947
@pytest.mark.usefixtures("init_components")
29532948
async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
29542949
"""Test that intent recognition results are cached for fuzzy matches."""
2955-
agent = hass.data[DATA_DEFAULT_ENTITY]
2956-
assert isinstance(agent, default_agent.DefaultAgent)
2950+
agent = async_get_agent(hass)
29572951

29582952
# There is no entity named test light
29592953
user_input = ConversationInput(
@@ -2982,8 +2976,7 @@ async def test_intent_cache_fuzzy(hass: HomeAssistant) -> None:
29822976
@pytest.mark.usefixtures("init_components")
29832977
async def test_entities_filtered_by_input(hass: HomeAssistant) -> None:
29842978
"""Test that entities are filtered by the input text before intent matching."""
2985-
agent = hass.data[DATA_DEFAULT_ENTITY]
2986-
assert isinstance(agent, default_agent.DefaultAgent)
2979+
agent = async_get_agent(hass)
29872980

29882981
# Only the switch is exposed
29892982
hass.states.async_set("light.test_light", "off")
@@ -3165,7 +3158,7 @@ async def test_handle_intents_with_response_errors(
31653158
assert await async_setup_component(hass, "climate", {})
31663159
area_registry.async_create("living room")
31673160

3168-
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
3161+
agent = async_get_agent(hass)
31693162

31703163
user_input = ConversationInput(
31713164
text="What is the temperature in the living room?",
@@ -3203,7 +3196,7 @@ async def test_handle_intents_filters_results(
32033196
assert await async_setup_component(hass, "climate", {})
32043197
area_registry.async_create("living room")
32053198

3206-
agent: default_agent.DefaultAgent = hass.data[DATA_DEFAULT_ENTITY]
3199+
agent = async_get_agent(hass)
32073200

32083201
user_input = ConversationInput(
32093202
text="What is the temperature in the living room?",
@@ -3363,7 +3356,7 @@ async def test_fuzzy_matching(
33633356
assert await async_setup_component(hass, "intent", {})
33643357
await light_intent.async_setup_intents(hass)
33653358

3366-
agent = hass.data[DATA_DEFAULT_ENTITY]
3359+
agent = async_get_agent(hass)
33673360
agent.fuzzy_matching = fuzzy_matching
33683361

33693362
area_office = area_registry.async_get_or_create("office_id")

0 commit comments

Comments
 (0)