|
27 | 27 | API_NOTIFICATIONS_ENDPOINT, |
28 | 28 | API_REGISTRATION_ENDPOINT, |
29 | 29 | API_REGISTRATION_HEADERS, |
30 | | - AUTOMATION_DEVICEHUB_ENDPOINT, |
| 30 | + AUTOMATION_CHALLENGE_ENDPOINT, |
31 | 31 | DEFAULT_STATE_FILE, |
32 | 32 | DEFAULT_USER_AGENT, |
33 | 33 | FB_APP_ID, |
|
51 | 51 | get_state, |
52 | 52 | set_state, |
53 | 53 | ) |
54 | | -from pyhilo.websocket import WebsocketClient |
| 54 | +from pyhilo.websocket import WebsocketClient, WebsocketManager |
55 | 55 |
|
56 | 56 |
|
57 | 57 | class API: |
@@ -81,9 +81,17 @@ def __init__( |
81 | 81 | self.device_attributes = get_device_attributes() |
82 | 82 | self.session: ClientSession = session |
83 | 83 | self._oauth_session = oauth_session |
| 84 | + self.websocket_devices: WebsocketClient |
| 85 | + # Backward compatibility during transition to websocket for challenges. Currently the HA Hilo integration |
| 86 | + # uses the .websocket attribute. Re-added this attribute and point to the same object as websocket_devices. |
| 87 | + # Should be removed once the transition to the challenge websocket is completed everywhere. |
84 | 88 | self.websocket: WebsocketClient |
| 89 | + self.websocket_challenges: WebsocketClient |
85 | 90 | self.log_traces = log_traces |
86 | 91 | self._get_device_callbacks: list[Callable[..., Any]] = [] |
| 92 | + self.ws_url: str = "" |
| 93 | + self.ws_token: str = "" |
| 94 | + self.endpoint: str = "" |
87 | 95 |
|
88 | 96 | @classmethod |
89 | 97 | async def async_create( |
@@ -132,6 +140,9 @@ async def async_get_access_token(self) -> str: |
132 | 140 | if not self._oauth_session.valid_token: |
133 | 141 | await self._oauth_session.async_ensure_token_valid() |
134 | 142 |
|
| 143 | + access_token = str(self._oauth_session.token["access_token"]) |
| 144 | + LOG.debug(f"ic-dev21 access token is {access_token}") |
| 145 | + |
135 | 146 | return str(self._oauth_session.token["access_token"]) |
136 | 147 |
|
137 | 148 | def dev_atts( |
@@ -216,17 +227,24 @@ async def _async_request( |
216 | 227 | :rtype: dict[str, Any] |
217 | 228 | """ |
218 | 229 | kwargs.setdefault("headers", self.headers) |
| 230 | + access_token = await self.async_get_access_token() |
| 231 | + |
219 | 232 | if endpoint.startswith(API_REGISTRATION_ENDPOINT): |
220 | 233 | kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS} |
221 | 234 | if endpoint.startswith(FB_INSTALL_ENDPOINT): |
222 | 235 | kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS} |
223 | 236 | if endpoint.startswith(ANDROID_CLIENT_ENDPOINT): |
224 | 237 | kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS} |
225 | 238 | if host == API_HOSTNAME: |
226 | | - access_token = await self.async_get_access_token() |
227 | 239 | kwargs["headers"]["authorization"] = f"Bearer {access_token}" |
228 | 240 | kwargs["headers"]["Host"] = host |
229 | 241 |
|
| 242 | + # ic-dev21 trying Leicas suggestion |
| 243 | + if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT): |
| 244 | + # remove Ocp-Apim-Subscription-Key header to avoid 401 error |
| 245 | + kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None) |
| 246 | + kwargs["headers"]["authorization"] = f"Bearer {access_token}" |
| 247 | + |
230 | 248 | data: dict[str, Any] = {} |
231 | 249 | url = parse.urljoin(f"https://{host}", endpoint) |
232 | 250 | if self.log_traces: |
@@ -303,8 +321,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None: |
303 | 321 | LOG.info( |
304 | 322 | "401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}" |
305 | 323 | ) |
| 324 | + LOG.info(f"401 detected on {err.request_info.url}") |
306 | 325 | async with self._backoff_refresh_lock_ws: |
307 | | - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() |
| 326 | + await self.refresh_ws_token() |
308 | 327 | await self.get_websocket_params() |
309 | 328 | return |
310 | 329 |
|
@@ -354,30 +373,26 @@ async def _async_post_init(self) -> None: |
354 | 373 | LOG.debug("Websocket postinit") |
355 | 374 | await self._get_fid() |
356 | 375 | await self._get_device_token() |
357 | | - await self.refresh_ws_token() |
358 | | - self.websocket = WebsocketClient(self) |
359 | 376 |
|
360 | | - async def refresh_ws_token(self) -> None: |
361 | | - (self.ws_url, self.ws_token) = await self.post_devicehub_negociate() |
362 | | - await self.get_websocket_params() |
363 | | - |
364 | | - async def post_devicehub_negociate(self) -> tuple[str, str]: |
365 | | - LOG.debug("Getting websocket url") |
366 | | - url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate" |
367 | | - LOG.debug(f"devicehub URL is {url}") |
368 | | - resp = await self.async_request("post", url) |
369 | | - ws_url = resp.get("url") |
370 | | - ws_token = resp.get("accessToken") |
371 | | - LOG.debug("Calling set_state devicehub_negotiate") |
372 | | - await set_state( |
373 | | - self._state_yaml, |
374 | | - "websocket", |
375 | | - { |
376 | | - "url": ws_url, |
377 | | - "token": ws_token, |
378 | | - }, |
| 377 | + # Initialize WebsocketManager ic-dev21 |
| 378 | + self.websocket_manager = WebsocketManager( |
| 379 | + self.session, self.async_request, self._state_yaml, set_state |
379 | 380 | ) |
380 | | - return (ws_url, ws_token) |
| 381 | + await self.websocket_manager.initialize_websockets() |
| 382 | + |
| 383 | + # Create both websocket clients |
| 384 | + # ic-dev21 need to work on this as it can't lint as is, may need to |
| 385 | + # instantiate differently |
| 386 | + self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub) |
| 387 | + |
| 388 | + # For backward compatibility during the transition to challengehub websocket |
| 389 | + self.websocket = self.websocket_devices |
| 390 | + self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub) |
| 391 | + |
| 392 | + async def refresh_ws_token(self) -> None: |
| 393 | + """Refresh the websocket token.""" |
| 394 | + await self.websocket_manager.refresh_token(self.websocket_manager.devicehub) |
| 395 | + await self.websocket_manager.refresh_token(self.websocket_manager.challengehub) |
381 | 396 |
|
382 | 397 | async def get_websocket_params(self) -> None: |
383 | 398 | uri = parse.urlparse(self.ws_url) |
|
0 commit comments