|
1 | | -import time |
2 | 1 | from enum import StrEnum |
3 | 2 | from typing import Any |
4 | 3 |
|
5 | | -import httpx |
6 | | -import structlog |
7 | 4 | from githubkit import ( |
8 | 5 | AppAuthStrategy, |
9 | 6 | AppInstallationAuthStrategy, |
10 | 7 | GitHub, |
11 | 8 | Response, |
12 | 9 | TokenAuthStrategy, |
13 | | - utils, |
14 | | - webhooks, |
15 | 10 | ) |
16 | | -from githubkit.typing import Missing |
17 | | -from pydantic import BaseModel, Field |
18 | 11 |
|
19 | 12 | from polar.config import settings |
20 | | -from polar.locker import Locker |
21 | | -from polar.models.user import OAuthAccount, OAuthPlatform, User |
22 | | -from polar.postgres import AsyncSession |
23 | | -from polar.user.oauth_service import oauth_account_service |
24 | | - |
25 | | -log = structlog.get_logger() |
26 | 13 |
|
27 | 14 |
|
28 | 15 | class UnexpectedStatusCode(Exception): ... |
@@ -72,123 +59,6 @@ def ensure_expected_response( |
72 | 59 | ############################################################################### |
73 | 60 |
|
74 | 61 |
|
75 | | -class RefreshAccessToken(BaseModel): |
76 | | - access_token: str = Field(default=...) |
77 | | - # The number of seconds until access_token expires (will always be 28800) |
78 | | - expires_in: int = Field(default=...) |
79 | | - # A new refresh token (is only set if the app is using expiring refresh tokens) |
80 | | - refresh_token: str | None = Field(default=...) |
81 | | - # The value will always be 15897600 (6 months) unless token expiration is disabled |
82 | | - refresh_token_expires_in: int | None = Field(default=...) |
83 | | - # Always an empty string |
84 | | - scope: str = Field(default=...) |
85 | | - # Always "bearer" |
86 | | - token_type: str = Field(default=...) |
87 | | - |
88 | | - |
89 | | -async def get_user_client( |
90 | | - session: AsyncSession, locker: Locker, user: User |
91 | | -) -> GitHub[TokenAuthStrategy]: |
92 | | - oauth = await oauth_account_service.get_by_platform_and_user_id( |
93 | | - session, OAuthPlatform.github, user.id |
94 | | - ) |
95 | | - if not oauth: |
96 | | - raise Exception("no github oauth account found") |
97 | | - |
98 | | - return await get_refreshed_oauth_client(session, locker, oauth) |
99 | | - |
100 | | - |
101 | | -async def refresh_oauth_account( |
102 | | - session: AsyncSession, locker: Locker, oauth: OAuthAccount |
103 | | -) -> OAuthAccount: |
104 | | - if oauth.platform != OAuthPlatform.github: |
105 | | - raise Exception("unexpected platform") |
106 | | - |
107 | | - if not oauth.should_refresh_access_token(): |
108 | | - return oauth |
109 | | - |
110 | | - async with locker.lock( |
111 | | - f"oauth_refresh:{oauth.id}", |
112 | | - timeout=10.0, |
113 | | - blocking_timeout=10.0, |
114 | | - ): |
115 | | - # first, reload from DB, a concurrent process might have already refreshed this token |
116 | | - # (and used the refresh token). |
117 | | - oauth_db = await oauth_account_service.get(session, oauth.id) |
118 | | - |
119 | | - if not oauth_db: |
120 | | - raise Exception("oauth account not found") |
121 | | - |
122 | | - # token is already refreshed |
123 | | - if not oauth_db.should_refresh_access_token(): |
124 | | - return oauth_db |
125 | | - |
126 | | - # refresh token |
127 | | - async with httpx.AsyncClient() as http_client: |
128 | | - response = await http_client.post( |
129 | | - "https://github.com/login/oauth/access_token", |
130 | | - params={ |
131 | | - "client_id": settings.GITHUB_CLIENT_ID, |
132 | | - "client_secret": settings.GITHUB_CLIENT_SECRET, |
133 | | - "refresh_token": oauth.refresh_token, |
134 | | - "grant_type": "refresh_token", |
135 | | - }, |
136 | | - headers={"Accept": "application/json"}, |
137 | | - ) |
138 | | - if response.status_code != 200: |
139 | | - log.error( |
140 | | - "github.auth.refresh.error", |
141 | | - user_id=oauth_db.user_id, |
142 | | - oauth_id=oauth_db.id, |
143 | | - http_code=response.status_code, |
144 | | - ) |
145 | | - return oauth_db |
146 | | - |
147 | | - data = response.json() |
148 | | - # GitHub returns 200 in case of errors, but with an error payload |
149 | | - error = data.get("error", None) |
150 | | - if error: |
151 | | - log.error( |
152 | | - "github.auth.refresh.error", |
153 | | - user_id=oauth_db.user_id, |
154 | | - oauth_id=oauth_db.id, |
155 | | - http_code=response.status_code, |
156 | | - error=error, |
157 | | - error_description=data.get("error_description", None), |
158 | | - ) |
159 | | - return oauth_db |
160 | | - |
161 | | - refreshed = RefreshAccessToken.model_validate(data) |
162 | | - |
163 | | - # update |
164 | | - epoch_now = int(time.time()) |
165 | | - oauth_db.access_token = refreshed.access_token |
166 | | - oauth_db.expires_at = epoch_now + refreshed.expires_in |
167 | | - if refreshed.refresh_token: |
168 | | - oauth_db.refresh_token = refreshed.refresh_token |
169 | | - |
170 | | - if refreshed.refresh_token_expires_in: |
171 | | - oauth_db.refresh_token_expires_at = ( |
172 | | - epoch_now + refreshed.refresh_token_expires_in |
173 | | - ) |
174 | | - |
175 | | - log.info( |
176 | | - "github.auth.refresh.succeeded", |
177 | | - user_id=oauth.user_id, |
178 | | - platform=oauth.platform, |
179 | | - ) |
180 | | - session.add(oauth_db) |
181 | | - await session.flush() |
182 | | - return oauth_db |
183 | | - |
184 | | - |
185 | | -async def get_refreshed_oauth_client( |
186 | | - session: AsyncSession, locker: Locker, oauth: OAuthAccount |
187 | | -) -> GitHub[TokenAuthStrategy]: |
188 | | - refreshed_oauth = await refresh_oauth_account(session, locker, oauth) |
189 | | - return get_client(refreshed_oauth.access_token) |
190 | | - |
191 | | - |
192 | 62 | def get_client(access_token: str) -> GitHub[TokenAuthStrategy]: |
193 | 63 | return GitHub(access_token, http_cache=False) |
194 | 64 |
|
@@ -231,12 +101,8 @@ def get_app_installation_client( |
231 | 101 | "get_client", |
232 | 102 | "get_app_client", |
233 | 103 | "get_app_installation_client", |
234 | | - "get_user_client", |
235 | 104 | "GitHub", |
236 | | - "Missing", |
237 | 105 | "AppInstallationAuthStrategy", |
238 | 106 | "TokenAuthStrategy", |
239 | | - "utils", |
240 | 107 | "Response", |
241 | | - "webhooks", |
242 | 108 | ] |
0 commit comments