Skip to content

Commit f91e3ad

Browse files
committed
Make access token awaitable
1 parent bb80842 commit f91e3ad

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

teslemetry_stream/stream.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from collections.abc import Callable
2-
from typing import Any
3-
import aiohttp
41
import asyncio
52
import json
63
import logging
74
from datetime import datetime, timezone
5+
from typing import Any, Awaitable, Callable
6+
7+
import aiohttp
88

9-
from .vehicle import TeslemetryStreamVehicle
109
from .exception import TeslemetryStreamEnded
10+
from .vehicle import TeslemetryStreamVehicle
1111

1212
LOGGER = logging.getLogger(__package__)
1313

14+
1415
class TeslemetryStream:
1516
"""Teslemetry Stream Client"""
1617

@@ -19,7 +20,7 @@ class TeslemetryStream:
1920
def __init__(
2021
self,
2122
session: aiohttp.ClientSession,
22-
access_token: str,
23+
access_token: str | Callable[[], Awaitable[str]],
2324
server: str = "api.teslemetry.com",
2425
vin: str | None = None,
2526
parse_timestamp: bool = False,
@@ -41,19 +42,30 @@ def __init__(
4142
self.active: bool = False
4243
self.server = server
4344
self.vin = vin
44-
self._listeners: dict[Callable, tuple[Callable[[dict[str,Any]],None], dict | None]] = {}
45-
self._connection_listeners: dict[Callable, Callable[[bool],None]] = {}
45+
self._listeners: dict[
46+
Callable, tuple[Callable[[dict[str, Any]], None], dict | None]
47+
] = {}
48+
self._connection_listeners: dict[Callable, Callable[[bool], None]] = {}
4649
self._session = session
47-
self._headers = {"Authorization": f"Bearer {access_token}", "X-Library": "python teslemetry-stream"}
50+
self.access_token = access_token
4851
self.parse_timestamp = parse_timestamp
4952
self.manual = manual
5053
self.retries: int = 0
5154
self.vehicles: dict[str, TeslemetryStreamVehicle] = {}
5255

53-
if(self.vin):
56+
if self.vin:
5457
self.vehicle: TeslemetryStreamVehicle = self.get_vehicle(self.vin)
5558
self.vehicles[self.vin] = self.vehicle
5659

60+
async def headers(self) -> dict[str, str]:
61+
if callable(self.access_token):
62+
access_token = await self.access_token()
63+
else:
64+
access_token = self.access_token
65+
return {
66+
"Authorization": f"Bearer {access_token}",
67+
"X-Library": "python teslemetry-stream",
68+
}
5769

5870
def get_vehicle(self, vin: str) -> TeslemetryStreamVehicle:
5971
"""
@@ -83,20 +95,21 @@ async def get_config(self, vin: str | None = None) -> None:
8395
"""
8496
if not self.server:
8597
await self.find_server()
86-
if hasattr(self, 'vehicle'):
98+
if hasattr(self, "vehicle"):
8799
await self.vehicle.get_config()
88100

89101
async def find_server(self) -> None:
90102
"""
91103
Find the server using metadata.
92104
"""
105+
headers = await self.headers()
93106
req = await self._session.get(
94107
"https://api.teslemetry.com/api/metadata",
95-
headers=self._headers,
108+
headers=headers,
96109
raise_for_status=True,
97110
)
98111
response = await req.json()
99-
self.server = f"{response["region"].lower()}.teslemetry.com"
112+
self.server = f"{response['region'].lower()}.teslemetry.com"
100113

101114
async def update_fields(self, fields: dict, vin: str) -> dict:
102115
"""
@@ -106,9 +119,10 @@ async def update_fields(self, fields: dict, vin: str) -> dict:
106119
:param vin: Vehicle Identification Number.
107120
:return: Response JSON as a dictionary.
108121
"""
122+
headers = await self.headers()
109123
resp = await self._session.patch(
110124
f"https://api.teslemetry.com/api/config/{self.vin}",
111-
headers=self._headers,
125+
headers=headers,
112126
json={"fields": fields},
113127
raise_for_status=False,
114128
)
@@ -124,9 +138,10 @@ async def replace_fields(self, fields: dict, vin: str) -> dict:
124138
:param vin: Vehicle Identification Number.
125139
:return: Response JSON as a dictionary.
126140
"""
141+
headers = await self.headers()
127142
resp = await self._session.post(
128143
f"https://api.teslemetry.com/api/config/{self.vin}",
129-
headers=self._headers,
144+
headers=headers,
130145
json={"fields": fields},
131146
raise_for_status=False,
132147
)
@@ -154,6 +169,7 @@ def async_add_connection_listener(
154169
:param callback: Callback function to handle connection state changes.
155170
:return: Function to remove the listener.
156171
"""
172+
157173
def remove_listener() -> None:
158174
"""
159175
Remove connection listener.
@@ -181,14 +197,15 @@ async def connect(self) -> None:
181197
url = f"https://{self.server}/sse"
182198
if self.vin:
183199
url += f"/{self.vin}"
200+
headers = await self.headers()
184201
self._response = await self._session.get(
185202
url,
186-
headers=self._headers,
203+
headers=headers,
187204
raise_for_status=True,
188205
timeout=aiohttp.ClientTimeout(
189206
connect=5, sock_connect=5, sock_read=30, total=None
190207
),
191-
chunked=True
208+
chunked=True,
192209
)
193210
LOGGER.debug(
194211
"Connected to %s with status %s", self._response.url, self._response.status
@@ -223,7 +240,7 @@ def __aiter__(self):
223240
self.active = True
224241
return self
225242

226-
async def __anext__(self) -> dict:
243+
async def __anext__(self) -> dict[str, Any]:
227244
"""
228245
Return next event.
229246
@@ -262,7 +279,7 @@ async def __anext__(self) -> dict:
262279
except aiohttp.ClientError as error:
263280
LOGGER.warning("Client error: %s", repr(error))
264281
self.close()
265-
delay = min(2 ** self.retries, 600)
282+
delay = min(2**self.retries, 600)
266283
LOGGER.debug("Reconnecting in %s seconds", delay)
267284
await asyncio.sleep(delay)
268285
self.retries += 1
@@ -272,7 +289,6 @@ async def __anext__(self) -> dict:
272289
LOGGER.debug("Reconnecting in %s seconds", 1)
273290
await asyncio.sleep(1)
274291

275-
276292
def async_add_listener(
277293
self, callback: Callable, filters: dict | None = None
278294
) -> Callable[[], None]:
@@ -316,16 +332,17 @@ async def listen(self):
316332
LOGGER.error("Uncaught error in listener: %s", error)
317333
LOGGER.debug("Listen has finished")
318334

319-
def listen_Credits(self, callback: Callable[[dict[str, str | int]], None]) -> Callable[[], None]:
335+
def listen_Credits(
336+
self, callback: Callable[[dict[str, str | int]], None]
337+
) -> Callable[[], None]:
320338
"""
321339
Listen for credits update.
322340
323341
:param callback: Callback function to handle credits update.
324342
:return: Function to remove the listener.
325343
"""
326344
return self.async_add_listener(
327-
lambda x: callback(x["credits"]),
328-
{"credits": None}
345+
lambda x: callback(x["credits"]), {"credits": None}
329346
)
330347

331348
def listen_Balance(self, callback: Callable[[int], None]) -> Callable[[], None]:
@@ -336,10 +353,10 @@ def listen_Balance(self, callback: Callable[[int], None]) -> Callable[[], None]:
336353
:return: Function to remove the listener.
337354
"""
338355
return self.async_add_listener(
339-
lambda x: callback(x["credits"]["balance"]),
340-
{"credits": {"balance": None}}
356+
lambda x: callback(x["credits"]["balance"]), {"credits": {"balance": None}}
341357
)
342358

359+
343360
def recursive_match(dict1, dict2):
344361
"""
345362
Recursively match dict1 with dict2.

teslemetry_stream/vehicle.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,10 @@ def config(self) -> dict:
8484
async def get_config(self) -> None:
8585
"""Get the current configuration for the vehicle."""
8686

87+
headers = await self.stream.headers()
8788
req = await self.stream._session.get(
8889
f"https://api.teslemetry.com/api/config/{self.vin}",
89-
headers=self.stream._headers,
90+
headers=headers,
9091
raise_for_status=False,
9192
)
9293
if req.status == 200:
@@ -133,19 +134,21 @@ async def update_config(self, config: dict) -> None:
133134

134135
async def patch_config(self, config: dict) -> dict[str, str | dict]:
135136
"""Modify the configuration for the vehicle."""
137+
headers = await self.stream.headers()
136138
resp = await self.stream._session.patch(
137139
f"https://api.teslemetry.com/api/config/{self.vin}",
138-
headers=self.stream._headers,
140+
headers=headers,
139141
json=config,
140142
raise_for_status=False,
141143
)
142144
return await resp.json()
143145

144146
async def post_config(self, config: dict) -> dict[str, str | dict]:
145147
"""Overwrite the configuration for the vehicle."""
148+
headers = await self.stream.headers()
146149
resp = await self.stream._session.post(
147150
f"https://api.teslemetry.com/api/config/{self.vin}",
148-
headers=self.stream._headers,
151+
headers=headers,
149152
json=config,
150153
raise_for_status=False,
151154
)

0 commit comments

Comments
 (0)