Skip to content

Commit 2281a37

Browse files
authored
Merge pull request #236 from fersingb/Websocket
Added backward compatibility in the websocket branch
2 parents 7b088ec + 16c0f4b commit 2281a37

File tree

5 files changed

+224
-37
lines changed

5 files changed

+224
-37
lines changed

pyhilo/api.py

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
API_NOTIFICATIONS_ENDPOINT,
2828
API_REGISTRATION_ENDPOINT,
2929
API_REGISTRATION_HEADERS,
30-
AUTOMATION_DEVICEHUB_ENDPOINT,
30+
AUTOMATION_CHALLENGE_ENDPOINT,
3131
DEFAULT_STATE_FILE,
3232
DEFAULT_USER_AGENT,
3333
FB_APP_ID,
@@ -51,7 +51,7 @@
5151
get_state,
5252
set_state,
5353
)
54-
from pyhilo.websocket import WebsocketClient
54+
from pyhilo.websocket import WebsocketClient, WebsocketManager
5555

5656

5757
class API:
@@ -81,9 +81,17 @@ def __init__(
8181
self.device_attributes = get_device_attributes()
8282
self.session: ClientSession = session
8383
self._oauth_session = oauth_session
84+
self.websocket_devices: WebsocketClient
85+
# Backward compatibility during transition to websocket for challenges. Currently the HA Hilo integration
86+
# uses the .websocket attribute. Re-added this attribute and point to the same object as websocket_devices.
87+
# Should be removed once the transition to the challenge websocket is completed everywhere.
8488
self.websocket: WebsocketClient
89+
self.websocket_challenges: WebsocketClient
8590
self.log_traces = log_traces
8691
self._get_device_callbacks: list[Callable[..., Any]] = []
92+
self.ws_url: str = ""
93+
self.ws_token: str = ""
94+
self.endpoint: str = ""
8795

8896
@classmethod
8997
async def async_create(
@@ -132,6 +140,9 @@ async def async_get_access_token(self) -> str:
132140
if not self._oauth_session.valid_token:
133141
await self._oauth_session.async_ensure_token_valid()
134142

143+
access_token = str(self._oauth_session.token["access_token"])
144+
LOG.debug(f"ic-dev21 access token is {access_token}")
145+
135146
return str(self._oauth_session.token["access_token"])
136147

137148
def dev_atts(
@@ -216,17 +227,24 @@ async def _async_request(
216227
:rtype: dict[str, Any]
217228
"""
218229
kwargs.setdefault("headers", self.headers)
230+
access_token = await self.async_get_access_token()
231+
219232
if endpoint.startswith(API_REGISTRATION_ENDPOINT):
220233
kwargs["headers"] = {**kwargs["headers"], **API_REGISTRATION_HEADERS}
221234
if endpoint.startswith(FB_INSTALL_ENDPOINT):
222235
kwargs["headers"] = {**kwargs["headers"], **FB_INSTALL_HEADERS}
223236
if endpoint.startswith(ANDROID_CLIENT_ENDPOINT):
224237
kwargs["headers"] = {**kwargs["headers"], **ANDROID_CLIENT_HEADERS}
225238
if host == API_HOSTNAME:
226-
access_token = await self.async_get_access_token()
227239
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
228240
kwargs["headers"]["Host"] = host
229241

242+
# ic-dev21 trying Leicas suggestion
243+
if endpoint.startswith(AUTOMATION_CHALLENGE_ENDPOINT):
244+
# remove Ocp-Apim-Subscription-Key header to avoid 401 error
245+
kwargs["headers"].pop("Ocp-Apim-Subscription-Key", None)
246+
kwargs["headers"]["authorization"] = f"Bearer {access_token}"
247+
230248
data: dict[str, Any] = {}
231249
url = parse.urljoin(f"https://{host}", endpoint)
232250
if self.log_traces:
@@ -303,8 +321,9 @@ async def _async_handle_on_backoff(self, _: dict[str, Any]) -> None:
303321
LOG.info(
304322
"401 detected on websocket, refreshing websocket token. Old url: {self.ws_url} Old Token: {self.ws_token}"
305323
)
324+
LOG.info(f"401 detected on {err.request_info.url}")
306325
async with self._backoff_refresh_lock_ws:
307-
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
326+
await self.refresh_ws_token()
308327
await self.get_websocket_params()
309328
return
310329

@@ -354,30 +373,26 @@ async def _async_post_init(self) -> None:
354373
LOG.debug("Websocket postinit")
355374
await self._get_fid()
356375
await self._get_device_token()
357-
await self.refresh_ws_token()
358-
self.websocket = WebsocketClient(self)
359376

360-
async def refresh_ws_token(self) -> None:
361-
(self.ws_url, self.ws_token) = await self.post_devicehub_negociate()
362-
await self.get_websocket_params()
363-
364-
async def post_devicehub_negociate(self) -> tuple[str, str]:
365-
LOG.debug("Getting websocket url")
366-
url = f"{AUTOMATION_DEVICEHUB_ENDPOINT}/negotiate"
367-
LOG.debug(f"devicehub URL is {url}")
368-
resp = await self.async_request("post", url)
369-
ws_url = resp.get("url")
370-
ws_token = resp.get("accessToken")
371-
LOG.debug("Calling set_state devicehub_negotiate")
372-
await set_state(
373-
self._state_yaml,
374-
"websocket",
375-
{
376-
"url": ws_url,
377-
"token": ws_token,
378-
},
377+
# Initialize WebsocketManager ic-dev21
378+
self.websocket_manager = WebsocketManager(
379+
self.session, self.async_request, self._state_yaml, set_state
379380
)
380-
return (ws_url, ws_token)
381+
await self.websocket_manager.initialize_websockets()
382+
383+
# Create both websocket clients
384+
# ic-dev21 need to work on this as it can't lint as is, may need to
385+
# instantiate differently
386+
self.websocket_devices = WebsocketClient(self.websocket_manager.devicehub)
387+
388+
# For backward compatibility during the transition to challengehub websocket
389+
self.websocket = self.websocket_devices
390+
self.websocket_challenges = WebsocketClient(self.websocket_manager.challengehub)
391+
392+
async def refresh_ws_token(self) -> None:
393+
"""Refresh the websocket token."""
394+
await self.websocket_manager.refresh_token(self.websocket_manager.devicehub)
395+
await self.websocket_manager.refresh_token(self.websocket_manager.challengehub)
381396

382397
async def get_websocket_params(self) -> None:
383398
uri = parse.urlparse(self.ws_url)

pyhilo/const.py

100644100755
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@
4242

4343
# Automation server constant
4444
AUTOMATION_DEVICEHUB_ENDPOINT: Final = "/DeviceHub"
45+
AUTOMATION_CHALLENGE_ENDPOINT: Final = "/ChallengeHub"
46+
4547

4648
# Request constants
4749
DEFAULT_USER_AGENT: Final = f"PyHilo/{PYHILO_VERSION} HomeAssistant/{homeassistant.core.__version__} aiohttp/{aiohttp.__version__} Python/{platform.python_version()}"

pyhilo/event.py

100644100755
Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
"""Event object """
22
from datetime import datetime, timedelta, timezone
3+
import logging
34
import re
45
from typing import Any, cast
56

67
from pyhilo.util import camel_to_snake, from_utc_timestamp
78

9+
LOG = logging.getLogger(__package__)
10+
811

912
class Event:
1013
setting_deadline: datetime
@@ -126,9 +129,12 @@ def current_phase_times(self) -> dict[str, datetime]:
126129
@property
127130
def state(self) -> str:
128131
now = datetime.now(self.preheat_start.tzinfo)
129-
if self.pre_cold_start <= now < self.pre_cold_end:
132+
if self.pre_cold_start and self.pre_cold_start <= now < self.pre_cold_end:
130133
return "pre_cold"
131-
elif self.appreciation_start <= now < self.appreciation_end:
134+
elif (
135+
self.appreciation_start
136+
and self.appreciation_start <= now < self.appreciation_end
137+
):
132138
return "appreciation"
133139
elif self.preheat_start > now:
134140
return "scheduled"
@@ -138,9 +144,12 @@ def state(self) -> str:
138144
return "reduction"
139145
elif self.recovery_start <= now < self.recovery_end:
140146
return "recovery"
147+
elif now >= self.recovery_end + timedelta(minutes=5):
148+
return "off"
141149
elif now >= self.recovery_end:
142150
return "completed"
143151
elif self.progress:
144152
return self.progress
153+
145154
else:
146155
return "unknown"

pyhilo/util/__init__.py

100644100755
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Define utility modules."""
22
import asyncio
33
from datetime import datetime, timedelta
4+
import logging
45
import re
56
from typing import Any, Callable
67

@@ -9,6 +10,9 @@
910

1011
from pyhilo.const import LOG # noqa: F401
1112

13+
LOG = logging.getLogger(__package__)
14+
15+
1216
CAMEL_REX_1 = re.compile("(.)([A-Z][a-z]+)")
1317
CAMEL_REX_2 = re.compile("([a-z0-9])([A-Z])")
1418

@@ -35,7 +39,11 @@ def snake_to_camel(string: str) -> str:
3539
def from_utc_timestamp(date_string: str) -> datetime:
3640
from_zone = tz.tzutc()
3741
to_zone = tz.tzlocal()
38-
return parse(date_string).replace(tzinfo=from_zone).astimezone(to_zone)
42+
dt = parse(date_string)
43+
if dt.tzinfo is None: # Only replace tzinfo if not already set
44+
dt = dt.replace(tzinfo=from_zone)
45+
output = dt.astimezone(to_zone)
46+
return output
3947

4048

4149
def time_diff(ts1: datetime, ts2: datetime) -> timedelta:

0 commit comments

Comments
 (0)