|
11 | 11 | import voluptuous as vol |
12 | 12 |
|
13 | 13 | from homeassistant.components import conversation |
| 14 | +from homeassistant.components.cloud.const import AI_TASK_ENTITY_UNIQUE_ID, DOMAIN |
14 | 15 | from homeassistant.components.cloud.entity import ( |
15 | 16 | BaseCloudLLMEntity, |
16 | 17 | _convert_content_to_param, |
17 | 18 | _format_structured_output, |
18 | 19 | ) |
19 | 20 | from homeassistant.core import HomeAssistant |
20 | 21 | from homeassistant.exceptions import HomeAssistantError |
21 | | -from homeassistant.helpers import llm, selector |
| 22 | +from homeassistant.helpers import entity_registry as er, llm, selector |
| 23 | +from homeassistant.setup import async_setup_component |
22 | 24 |
|
23 | 25 | from tests.common import MockConfigEntry |
24 | 26 |
|
@@ -219,3 +221,66 @@ async def test_prepare_chat_for_generation_passes_messages_through( |
219 | 221 |
|
220 | 222 | assert response["messages"] == messages |
221 | 223 | assert response["conversation_id"] == "conversation-id" |
| 224 | + |
| 225 | + |
| 226 | +async def test_async_handle_chat_log_service_sets_structured_output_non_strict( |
| 227 | + hass: HomeAssistant, |
| 228 | + cloud: MagicMock, |
| 229 | + entity_registry: er.EntityRegistry, |
| 230 | + mock_cloud_login: None, |
| 231 | +) -> None: |
| 232 | + """Ensure structured output requests always disable strict validation via service.""" |
| 233 | + assert await async_setup_component(hass, DOMAIN, {}) |
| 234 | + await hass.async_block_till_done() |
| 235 | + |
| 236 | + on_start_callback = cloud.register_on_start.call_args[0][0] |
| 237 | + await on_start_callback() |
| 238 | + await hass.async_block_till_done() |
| 239 | + |
| 240 | + entity_id = entity_registry.async_get_entity_id( |
| 241 | + "ai_task", DOMAIN, AI_TASK_ENTITY_UNIQUE_ID |
| 242 | + ) |
| 243 | + assert entity_id is not None |
| 244 | + |
| 245 | + async def _empty_stream(): |
| 246 | + return |
| 247 | + |
| 248 | + async def _fake_delta_stream( |
| 249 | + self: conversation.ChatLog, |
| 250 | + agent_id: str, |
| 251 | + stream, |
| 252 | + ): |
| 253 | + content = conversation.AssistantContent( |
| 254 | + agent_id=agent_id, content='{"value": "ok"}' |
| 255 | + ) |
| 256 | + self.async_add_assistant_content_without_tools(content) |
| 257 | + yield content |
| 258 | + |
| 259 | + cloud.llm.async_generate_data = AsyncMock(return_value=_empty_stream()) |
| 260 | + |
| 261 | + with patch( |
| 262 | + "homeassistant.components.conversation.chat_log.ChatLog.async_add_delta_content_stream", |
| 263 | + _fake_delta_stream, |
| 264 | + ): |
| 265 | + await hass.services.async_call( |
| 266 | + "ai_task", |
| 267 | + "generate_data", |
| 268 | + { |
| 269 | + "entity_id": entity_id, |
| 270 | + "task_name": "Device Report", |
| 271 | + "instructions": "Provide value.", |
| 272 | + "structure": { |
| 273 | + "value": { |
| 274 | + "selector": {"text": None}, |
| 275 | + "required": True, |
| 276 | + } |
| 277 | + }, |
| 278 | + }, |
| 279 | + blocking=True, |
| 280 | + return_response=True, |
| 281 | + ) |
| 282 | + |
| 283 | + cloud.llm.async_generate_data.assert_awaited_once() |
| 284 | + _, kwargs = cloud.llm.async_generate_data.call_args |
| 285 | + |
| 286 | + assert kwargs["response_format"]["json_schema"]["strict"] is False |
0 commit comments