Skip to content

Commit cf22388

Browse files
arturpragaczfrenck
authored andcommitted
Use satellite entity area in the assist pipeline (home-assistant#153017)
1 parent 4058ca5 commit cf22388

File tree

2 files changed

+57
-23
lines changed

2 files changed

+57
-23
lines changed

homeassistant/components/assist_pipeline/pipeline.py

Lines changed: 43 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1308,7 +1308,9 @@ async def tts_input_stream_generator() -> AsyncGenerator[str]:
13081308
# instead of a full response.
13091309
all_targets_in_satellite_area = (
13101310
self._get_all_targets_in_satellite_area(
1311-
conversation_result.response, self._device_id
1311+
conversation_result.response,
1312+
self._satellite_id,
1313+
self._device_id,
13121314
)
13131315
)
13141316

@@ -1337,39 +1339,62 @@ async def tts_input_stream_generator() -> AsyncGenerator[str]:
13371339
return (speech, all_targets_in_satellite_area)
13381340

13391341
def _get_all_targets_in_satellite_area(
1340-
self, intent_response: intent.IntentResponse, device_id: str | None
1342+
self,
1343+
intent_response: intent.IntentResponse,
1344+
satellite_id: str | None,
1345+
device_id: str | None,
13411346
) -> bool:
13421347
"""Return true if all targeted entities were in the same area as the device."""
13431348
if (
1344-
(intent_response.response_type != intent.IntentResponseType.ACTION_DONE)
1345-
or (not intent_response.matched_states)
1346-
or (not device_id)
1349+
intent_response.response_type != intent.IntentResponseType.ACTION_DONE
1350+
or not intent_response.matched_states
13471351
):
13481352
return False
13491353

1354+
entity_registry = er.async_get(self.hass)
13501355
device_registry = dr.async_get(self.hass)
13511356

1352-
if (not (device := device_registry.async_get(device_id))) or (
1353-
not device.area_id
1357+
area_id: str | None = None
1358+
1359+
if (
1360+
satellite_id is not None
1361+
and (target_entity_entry := entity_registry.async_get(satellite_id))
1362+
is not None
13541363
):
1355-
return False
1364+
area_id = target_entity_entry.area_id
1365+
device_id = target_entity_entry.device_id
1366+
1367+
if area_id is None:
1368+
if device_id is None:
1369+
return False
1370+
1371+
device_entry = device_registry.async_get(device_id)
1372+
if device_entry is None:
1373+
return False
1374+
1375+
area_id = device_entry.area_id
1376+
if area_id is None:
1377+
return False
13561378

1357-
entity_registry = er.async_get(self.hass)
13581379
for state in intent_response.matched_states:
1359-
entity = entity_registry.async_get(state.entity_id)
1360-
if not entity:
1380+
target_entity_entry = entity_registry.async_get(state.entity_id)
1381+
if target_entity_entry is None:
13611382
return False
13621383

1363-
if (entity_area_id := entity.area_id) is None:
1364-
if (entity.device_id is None) or (
1365-
(entity_device := device_registry.async_get(entity.device_id))
1366-
is None
1367-
):
1384+
target_area_id = target_entity_entry.area_id
1385+
if target_area_id is None:
1386+
if target_entity_entry.device_id is None:
1387+
return False
1388+
1389+
target_device_entry = device_registry.async_get(
1390+
target_entity_entry.device_id
1391+
)
1392+
if target_device_entry is None:
13681393
return False
13691394

1370-
entity_area_id = entity_device.area_id
1395+
target_area_id = target_device_entry.area_id
13711396

1372-
if entity_area_id != device.area_id:
1397+
if target_area_id != area_id:
13731398
return False
13741399

13751400
return True

tests/components/assist_pipeline/test_pipeline.py

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1797,6 +1797,7 @@ async def stream_llm_response():
17971797
assert process_events(events) == snapshot
17981798

17991799

1800+
@pytest.mark.parametrize(("use_satellite_entity"), [True, False])
18001801
async def test_acknowledge(
18011802
hass: HomeAssistant,
18021803
init_components,
@@ -1805,6 +1806,7 @@ async def test_acknowledge(
18051806
entity_registry: er.EntityRegistry,
18061807
area_registry: ar.AreaRegistry,
18071808
device_registry: dr.DeviceRegistry,
1809+
use_satellite_entity: bool,
18081810
) -> None:
18091811
"""Test that acknowledge sound is played when targets are in the same area."""
18101812
area_1 = area_registry.async_get_or_create("area_1")
@@ -1819,12 +1821,16 @@ async def test_acknowledge(
18191821

18201822
entry = MockConfigEntry()
18211823
entry.add_to_hass(hass)
1822-
satellite = device_registry.async_get_or_create(
1824+
1825+
satellite = entity_registry.async_get_or_create("assist_satellite", "test", "1234")
1826+
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
1827+
1828+
satellite_device = device_registry.async_get_or_create(
18231829
config_entry_id=entry.entry_id,
18241830
connections=set(),
18251831
identifiers={("demo", "id-1234")},
18261832
)
1827-
device_registry.async_update_device(satellite.id, area_id=area_1.id)
1833+
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
18281834

18291835
events: list[assist_pipeline.PipelineEvent] = []
18301836
turn_on = async_mock_service(hass, "light", "turn_on")
@@ -1837,7 +1843,8 @@ async def _run(text: str) -> None:
18371843
pipeline_input = assist_pipeline.pipeline.PipelineInput(
18381844
intent_input=text,
18391845
session=mock_chat_session,
1840-
device_id=satellite.id,
1846+
satellite_id=satellite.entity_id if use_satellite_entity else None,
1847+
device_id=satellite_device.id if not use_satellite_entity else None,
18411848
run=assist_pipeline.pipeline.PipelineRun(
18421849
hass,
18431850
context=Context(),
@@ -1889,7 +1896,8 @@ def _reset() -> None:
18891896
)
18901897

18911898
# 3. Remove satellite device area
1892-
device_registry.async_update_device(satellite.id, area_id=None)
1899+
entity_registry.async_update_entity(satellite.entity_id, area_id=None)
1900+
device_registry.async_update_device(satellite_device.id, area_id=None)
18931901

18941902
_reset()
18951903
await _run("turn on light 1")
@@ -1900,7 +1908,8 @@ def _reset() -> None:
19001908
assert len(turn_on) == 1
19011909

19021910
# Restore
1903-
device_registry.async_update_device(satellite.id, area_id=area_1.id)
1911+
entity_registry.async_update_entity(satellite.entity_id, area_id=area_1.id)
1912+
device_registry.async_update_device(satellite_device.id, area_id=area_1.id)
19041913

19051914
# 4. Check device area instead of entity area
19061915
light_device = device_registry.async_get_or_create(

0 commit comments

Comments
 (0)