|
1 | 1 | from __future__ import annotations |
2 | 2 |
|
3 | 3 | import platform |
| 4 | +import re |
4 | 5 | import warnings |
| 6 | +from dataclasses import dataclass |
5 | 7 | from datetime import datetime, timezone |
6 | 8 | from pathlib import Path |
7 | | -from typing import TypedDict |
| 9 | +from typing import TypedDict, cast |
8 | 10 | from urllib.parse import urljoin |
9 | 11 |
|
10 | 12 | import requests |
| 13 | +from rich.prompt import IntPrompt |
| 14 | +from typing_extensions import Self |
11 | 15 |
|
12 | 16 | from logfire.exceptions import LogfireConfigError |
13 | 17 |
|
14 | | -from .utils import UnexpectedResponse |
| 18 | +from .utils import UnexpectedResponse, read_toml_file |
15 | 19 |
|
16 | 20 | HOME_LOGFIRE = Path.home() / '.logfire' |
17 | 21 | """Folder used to store global configuration, and user tokens.""" |
18 | 22 | DEFAULT_FILE = HOME_LOGFIRE / 'default.toml' |
19 | 23 | """File used to store user tokens.""" |
20 | 24 |
|
21 | 25 |
|
| 26 | +PYDANTIC_LOGFIRE_TOKEN_PATTERN = re.compile( |
| 27 | + r'^(?P<safe_part>pylf_v(?P<version>[0-9]+)_(?P<region>[a-z]+)_)(?P<token>[a-zA-Z0-9]+)$' |
| 28 | +) |
| 29 | + |
| 30 | + |
| 31 | +class _RegionData(TypedDict): |
| 32 | + base_url: str |
| 33 | + gcp_region: str |
| 34 | + |
| 35 | + |
| 36 | +REGIONS: dict[str, _RegionData] = { |
| 37 | + 'us': { |
| 38 | + 'base_url': 'https://logfire-us.pydantic.dev', |
| 39 | + 'gcp_region': 'us-east4', |
| 40 | + }, |
| 41 | + 'eu': { |
| 42 | + 'base_url': 'https://logfire-eu.pydantic.dev', |
| 43 | + 'gcp_region': 'europe-west4', |
| 44 | + }, |
| 45 | +} |
| 46 | +"""The existing Logfire regions.""" |
| 47 | + |
| 48 | + |
22 | 49 | class UserTokenData(TypedDict): |
23 | 50 | """User token data.""" |
24 | 51 |
|
25 | 52 | token: str |
26 | 53 | expiration: str |
27 | 54 |
|
28 | 55 |
|
29 | | -class DefaultFile(TypedDict): |
30 | | - """Content of the default.toml file.""" |
| 56 | +class UserTokensFileData(TypedDict, total=False): |
| 57 | + """Content of the file containing the user tokens.""" |
31 | 58 |
|
32 | 59 | tokens: dict[str, UserTokenData] |
33 | 60 |
|
34 | 61 |
|
| 62 | +@dataclass |
| 63 | +class UserToken: |
| 64 | + """A user token.""" |
| 65 | + |
| 66 | + token: str |
| 67 | + base_url: str |
| 68 | + expiration: str |
| 69 | + |
| 70 | + @classmethod |
| 71 | + def from_user_token_data(cls, base_url: str, token: UserTokenData) -> Self: |
| 72 | + return cls( |
| 73 | + token=token['token'], |
| 74 | + base_url=base_url, |
| 75 | + expiration=token['expiration'], |
| 76 | + ) |
| 77 | + |
| 78 | + @property |
| 79 | + def is_expired(self) -> bool: |
| 80 | + """Whether the token is expired.""" |
| 81 | + return datetime.now(tz=timezone.utc) >= datetime.fromisoformat(self.expiration.rstrip('Z')).replace( |
| 82 | + tzinfo=timezone.utc |
| 83 | + ) |
| 84 | + |
| 85 | + def __str__(self) -> str: |
| 86 | + region = 'us' |
| 87 | + if match := PYDANTIC_LOGFIRE_TOKEN_PATTERN.match(self.token): |
| 88 | + region = match.group('region') |
| 89 | + if region not in REGIONS: |
| 90 | + region = 'us' |
| 91 | + |
| 92 | + token_repr = f'{region.upper()} ({self.base_url}) - ' |
| 93 | + if match: |
| 94 | + token_repr += match.group('safe_part') + match.group('token')[:5] |
| 95 | + else: |
| 96 | + token_repr += self.token[:5] |
| 97 | + token_repr += '****' |
| 98 | + return token_repr |
| 99 | + |
| 100 | + |
| 101 | +@dataclass |
| 102 | +class UserTokenCollection: |
| 103 | + """A collection of user tokens, read from a user tokens file. |
| 104 | +
|
| 105 | + Args: |
| 106 | + path: The path where the user tokens will be stored. If the path doesn't exist, |
| 107 | + an empty collection is created. Defaults to `~/.logfire/default.toml`. |
| 108 | + """ |
| 109 | + |
| 110 | + user_tokens: dict[str, UserToken] |
| 111 | + """A mapping between base URLs and user tokens.""" |
| 112 | + |
| 113 | + path: Path |
| 114 | + """The path where the user tokens are stored.""" |
| 115 | + |
| 116 | + def __init__(self, path: Path | None = None) -> None: |
| 117 | + # FIXME: we can't set the default value of `path` to `DEFAULT_FILE`, otherwise |
| 118 | + # `mock.patch()` doesn't work: |
| 119 | + self.path = path if path is not None else DEFAULT_FILE |
| 120 | + try: |
| 121 | + data = cast(UserTokensFileData, read_toml_file(self.path)) |
| 122 | + except FileNotFoundError: |
| 123 | + data: UserTokensFileData = {} |
| 124 | + self.user_tokens = {url: UserToken(base_url=url, **token) for url, token in data.get('tokens', {}).items()} |
| 125 | + |
| 126 | + def get_token(self, base_url: str | None = None) -> UserToken: |
| 127 | + """Get a user token from the collection. |
| 128 | +
|
| 129 | + Args: |
| 130 | + base_url: Only look for user tokens valid for this base URL. If not provided, |
| 131 | + all the tokens of the collection will be considered: if only one token is |
| 132 | + available, it will be used, otherwise the user will be prompted to choose |
| 133 | + a token. |
| 134 | +
|
| 135 | + Raises: |
| 136 | + LogfireConfigError: If no user token is found (no token matched the base URL, |
| 137 | + the collection is empty, or the selected token is expired). |
| 138 | + """ |
| 139 | + tokens_list = list(self.user_tokens.values()) |
| 140 | + |
| 141 | + if base_url is not None: |
| 142 | + token = self.user_tokens.get(base_url) |
| 143 | + if token is None: |
| 144 | + raise LogfireConfigError( |
| 145 | + f'No user token was found matching the {base_url} Logfire URL. ' |
| 146 | + 'Please run `logfire auth` to authenticate.' |
| 147 | + ) |
| 148 | + elif len(tokens_list) == 1: |
| 149 | + token = tokens_list[0] |
| 150 | + elif len(tokens_list) >= 2: |
| 151 | + choices_str = '\n'.join( |
| 152 | + f'{i}. {token} ({"expired" if token.is_expired else "valid"})' |
| 153 | + for i, token in enumerate(tokens_list, start=1) |
| 154 | + ) |
| 155 | + int_choice = IntPrompt.ask( |
| 156 | + f'Multiple user tokens found. Please select one:\n{choices_str}\n', |
| 157 | + choices=[str(i) for i in range(1, len(tokens_list) + 1)], |
| 158 | + ) |
| 159 | + token = tokens_list[int_choice - 1] |
| 160 | + else: # tokens_list == [] |
| 161 | + raise LogfireConfigError('No user tokens are available. Please run `logfire auth` to authenticate.') |
| 162 | + |
| 163 | + if token.is_expired: |
| 164 | + raise LogfireConfigError(f'User token {token} is expired. Please run `logfire auth` to authenticate.') |
| 165 | + return token |
| 166 | + |
| 167 | + def is_logged_in(self, base_url: str | None = None) -> bool: |
| 168 | + """Check whether the user token collection contains at least one valid user token. |
| 169 | +
|
| 170 | + Args: |
| 171 | + base_url: Only check for user tokens valid for this base URL. If not provided, |
| 172 | + all the tokens of the collection will be considered. |
| 173 | + """ |
| 174 | + if base_url is not None: |
| 175 | + tokens = (t for t in self.user_tokens.values() if t.base_url == base_url) |
| 176 | + else: |
| 177 | + tokens = self.user_tokens.values() |
| 178 | + return any(not t.is_expired for t in tokens) |
| 179 | + |
| 180 | + def add_token(self, base_url: str, token: UserTokenData) -> UserToken: |
| 181 | + """Add a user token to the collection.""" |
| 182 | + self.user_tokens[base_url] = user_token = UserToken.from_user_token_data(base_url, token) |
| 183 | + self._dump() |
| 184 | + return user_token |
| 185 | + |
| 186 | + def _dump(self) -> None: |
| 187 | + """Dump the user token collection as TOML to the provided path.""" |
| 188 | + # There's no standard library package to write TOML files, so we'll write it manually. |
| 189 | + with self.path.open('w') as f: |
| 190 | + for base_url, user_token in self.user_tokens.items(): |
| 191 | + f.write(f'[tokens."{base_url}"]\n') |
| 192 | + f.write(f'token = "{user_token.token}"\n') |
| 193 | + f.write(f'expiration = "{user_token.expiration}"\n') |
| 194 | + |
| 195 | + |
35 | 196 | class NewDeviceFlow(TypedDict): |
36 | 197 | """Matches model of the same name in the backend.""" |
37 | 198 |
|
@@ -91,17 +252,3 @@ def poll_for_token(session: requests.Session, device_code: str, base_api_url: st |
91 | 252 | opt_user_token: UserTokenData | None = res.json() |
92 | 253 | if opt_user_token: |
93 | 254 | return opt_user_token |
94 | | - |
95 | | - |
96 | | -def is_logged_in(data: DefaultFile, logfire_url: str) -> bool: |
97 | | - """Check if the user is logged in. |
98 | | -
|
99 | | - Returns: |
100 | | - True if the user is logged in, False otherwise. |
101 | | - """ |
102 | | - for url, info in data['tokens'].items(): # pragma: no branch |
103 | | - # token expirations are in UTC |
104 | | - expiry_date = datetime.fromisoformat(info['expiration'].rstrip('Z')).replace(tzinfo=timezone.utc) |
105 | | - if url == logfire_url and datetime.now(tz=timezone.utc) < expiry_date: # pragma: no branch |
106 | | - return True |
107 | | - return False # pragma: no cover |
|
0 commit comments