Skip to content

Commit 4a45844

Browse files
committed
reauth
1 parent ba56ce4 commit 4a45844

File tree

2 files changed

+78
-93
lines changed

2 files changed

+78
-93
lines changed

bring_api/bring.py

Lines changed: 76 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,14 @@
11
"""Bring api implementation."""
22

33
import asyncio
4-
from collections.abc import AsyncGenerator
5-
from contextlib import asynccontextmanager
64
from http import HTTPStatus
75
from itertools import chain
86
import json
97
from json import JSONDecodeError
108
import logging
119
import os
1210
import time
11+
from typing import Any
1312

1413
import aiohttp
1514
from mashumaro.exceptions import MissingField
@@ -82,70 +81,65 @@ def __init__(
8281
self.headers = DEFAULT_HEADERS.copy()
8382

8483
self.loop = asyncio.get_running_loop()
85-
self.refresh_token = ""
86-
self.__expires_in: int = 0
84+
self.__refresh_token: str | None = None
85+
self.__access_token_expires_at: float | None = None
86+
self._etag: dict[str, str] = {}
87+
self._site_cache: dict[str, str] = {}
8788

8889
@property
89-
def expires_in(self) -> int:
90+
def _expires_at(self) -> float | None:
9091
"""Refresh token expiration."""
91-
return max(0, self.__expires_in - int(time.time()))
92+
return self.__access_token_expires_at
9293

93-
@expires_in.setter
94-
def expires_in(self, expires_in: int | str) -> None:
95-
self.__expires_in = int(time.time()) + int(expires_in)
94+
@_expires_at.setter
95+
def _expires_at(self, expires_in: int) -> None:
96+
self.__access_token_expires_at = time.time() + expires_in
97+
98+
@property
99+
def _token_expired(self) -> bool:
100+
"""True if access token expired."""
101+
102+
return (
103+
self.__access_token_expires_at is None
104+
or self.__access_token_expires_at < time.time()
105+
)
96106

97-
@asynccontextmanager
98107
async def _request(
99-
self, method: str, url: URL, retry: bool = True, **kwargs
100-
) -> AsyncGenerator[aiohttp.ClientResponse]:
108+
self, method: str, url: URL, retry: bool = False, **kwargs: Any
109+
) -> str:
101110
"""Handle request and ensure valid auth token."""
102-
103-
if not self.expires_in and retry:
111+
headers = self.headers.copy()
112+
if (self._token_expired or retry) and self.__refresh_token:
104113
await self.retrieve_new_access_token()
105114

115+
if method == "get" and (etag := self._etag.get(str(url))) and not retry:
116+
headers["If-None-Match"] = etag
117+
106118
try:
107-
async with self._session.request(
108-
method, url, headers=self.headers, **kwargs
109-
) as r:
110-
_LOGGER.debug(
111-
"Response from %s [%s]: %s", url, r.status, await r.text()
112-
)
119+
r = await self._session.request(method, url, headers=headers, **kwargs)
120+
_LOGGER.debug("Response from %s [%s]: %s", url, r.status, await r.text())
113121

114-
if r.status == HTTPStatus.UNAUTHORIZED:
115-
try:
116-
errmsg = BringErrorResponse.from_json(await r.text())
117-
except MissingField as e:
118-
raise BringParseException(
119-
f"Failed to parse response: {str(e)} "
120-
"This is likely a bug. Please report it at: https://github.com/miaucl/bring-api/issues",
121-
) from e
122-
except (JSONDecodeError, aiohttp.ClientError):
123-
_LOGGER.debug(
124-
"Exception: Cannot parse request response", exc_info=True
125-
)
126-
else:
127-
_LOGGER.debug("Exception: %s", repr(errmsg))
128-
if retry:
129-
try:
130-
await self.retrieve_new_access_token()
131-
except BringAuthException as e:
132-
raise BringAuthException from e
133-
else:
134-
async with self._request(
135-
method, url, False, **kwargs
136-
) as r:
137-
yield r
138-
else:
139-
raise BringAuthException
122+
if r.status == HTTPStatus.UNAUTHORIZED:
123+
try:
124+
errmsg = BringErrorResponse.from_json(await r.text())
125+
except MissingField as e:
126+
raise BringMissingFieldException(e) from e
127+
except (JSONDecodeError, aiohttp.ClientError):
128+
_LOGGER.debug(
129+
"Exception: Cannot parse error response", exc_info=True
130+
)
131+
else:
132+
_LOGGER.debug("Exception: Authentication failed: %s", repr(errmsg))
133+
if not retry:
134+
return await self._request(method, url, True, **kwargs)
140135

141-
r.raise_for_status()
136+
raise BringAuthException(
137+
"Loading list items failed due to authorization failure, "
138+
"the authorization token is invalid or expired."
139+
)
140+
141+
r.raise_for_status()
142142

143-
yield r
144-
except BringAuthException as e:
145-
raise BringAuthException(
146-
"Login failed due to authorization failure, "
147-
"please check your email and password."
148-
) from e
149143
except aiohttp.ClientResponseError as e:
150144
_LOGGER.debug("Exception: %s", repr(e), exc_info=True)
151145
raise BringRequestException(
@@ -161,6 +155,21 @@ async def _request(
161155
raise BringRequestException(
162156
"Request failed due to client connection error."
163157
) from e
158+
else:
159+
if r.status == HTTPStatus.NOT_MODIFIED and etag:
160+
try:
161+
return self._site_cache[etag]
162+
except KeyError:
163+
self._etag.pop(str(url), None)
164+
return await self._request(method, url, True, **kwargs)
165+
166+
body = await r.text()
167+
168+
if etag := r.headers.get("etag"):
169+
self._etag[str(url)] = etag
170+
self._site_cache[etag] = body
171+
172+
return body
164173

165174
async def login(self) -> BringAuthResponse:
166175
"""Try to login.
@@ -246,8 +255,8 @@ async def login(self) -> BringAuthResponse:
246255
self.headers["X-BRING-USER-UUID"] = self.uuid
247256
self.headers["X-BRING-PUBLIC-USER-UUID"] = self.public_uuid
248257
self.headers["Authorization"] = f"{data.token_type} {data.access_token}"
249-
self.refresh_token = data.refresh_token
250-
self.expires_in = data.expires_in
258+
self.__refresh_token = data.refresh_token
259+
self._expires_at = data.expires_in
251260

252261
locale = (await self.get_user_account()).userLocale
253262
self.headers["X-BRING-COUNTRY"] = locale.country
@@ -298,43 +307,17 @@ async def load_lists(self) -> BringListResponse:
298307
If the request fails due to invalid or expired authorization token.
299308
300309
"""
301-
try:
302-
url = self.url / "bringusers" / self.uuid / "lists"
303-
async with self._request("get", url) as r:
304-
if r.status == HTTPStatus.UNAUTHORIZED:
305-
try:
306-
errmsg = BringErrorResponse.from_json(await r.text())
307-
except (JSONDecodeError, aiohttp.ClientError):
308-
_LOGGER.debug(
309-
"Exception: Cannot parse request response:", exc_info=True
310-
)
311-
else:
312-
_LOGGER.debug("Exception: Cannot get lists: %s", errmsg.message)
313-
raise BringAuthException(
314-
"Loading lists failed due to authorization failure, "
315-
"the authorization token is invalid or expired."
316-
)
317310

318-
r.raise_for_status()
319-
320-
try:
321-
return BringListResponse.from_json(await r.text())
322-
except MissingField as e:
323-
raise BringMissingFieldException(e) from e
324-
except JSONDecodeError as e:
325-
_LOGGER.debug("Exception: Cannot get lists:", exc_info=True)
326-
raise BringParseException(
327-
"Loading lists failed during parsing of request response."
328-
) from e
329-
except TimeoutError as e:
330-
_LOGGER.debug("Exception: Cannot get lists:", exc_info=True)
331-
raise BringRequestException(
332-
"Loading lists failed due to connection timeout."
333-
) from e
334-
except aiohttp.ClientError as e:
311+
url = self.url / "bringusers" / self.uuid / "lists"
312+
r = await self._request("get", url)
313+
try:
314+
return BringListResponse.from_json(r)
315+
except MissingField as e:
316+
raise BringMissingFieldException(e) from e
317+
except JSONDecodeError as e:
335318
_LOGGER.debug("Exception: Cannot get lists:", exc_info=True)
336-
raise BringRequestException(
337-
"Loading lists failed due to request exception."
319+
raise BringParseException(
320+
"Loading lists failed during parsing of request response."
338321
) from e
339322

340323
async def get_list(self, list_uuid: str) -> BringItemsResponse:
@@ -1430,7 +1413,8 @@ async def retrieve_new_access_token(
14301413
If the request fails due to invalid or expired refresh token.
14311414
14321415
"""
1433-
refresh_token = refresh_token or self.refresh_token
1416+
if not (refresh_token := refresh_token or self.__refresh_token):
1417+
raise BringAuthException("Refresh token not found. Login required.")
14341418

14351419
user_data = {"grant_type": "refresh_token", "refresh_token": refresh_token}
14361420
try:
@@ -1489,7 +1473,7 @@ async def retrieve_new_access_token(
14891473
) from e
14901474

14911475
self.headers["Authorization"] = f"{data.token_type} {data.access_token}"
1492-
self.expires_in = data.expires_in
1476+
self._expires_at = data.expires_in
14931477

14941478
return data
14951479

tests/conftest.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,8 @@ async def aiohttp_client_session() -> AsyncGenerator[aiohttp.ClientSession]:
5858
async def bring_api_client(session: aiohttp.ClientSession) -> Bring:
5959
"""Create Bring instance."""
6060
bring = Bring(session, "EMAIL", "PASSWORD")
61-
bring.expires_in = 604799
61+
bring._expires_at = 604799
62+
6263
return bring
6364

6465

0 commit comments

Comments
 (0)