Skip to content

Commit 4e8ac79

Browse files
committed
Add better OAuth Redirect handling to web adapters.
1 parent 637f283 commit 4e8ac79

File tree

2 files changed

+44
-5
lines changed

2 files changed

+44
-5
lines changed

twitchio/web/aio_adapter.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,10 +316,12 @@ async def fetch_token(self, request: web.Request) -> FetchTokenPayload:
316316
if "code" not in request.query:
317317
return FetchTokenPayload(400, response=web.Response(status=400, text="No 'code' parameter provided."))
318318

319+
redirect = self._find_redirect(request)
320+
319321
try:
320322
resp: UserTokenPayload = await self.client._http.user_access_token(
321323
request.query["code"],
322-
redirect_uri=self.redirect_url,
324+
redirect_uri=redirect,
323325
)
324326
except HTTPException as e:
325327
logger.error("Exception raised while fetching Token in <%s>: %s", self.__class__.__qualname__, e)
@@ -378,7 +380,22 @@ async def oauth_callback(self, request: web.Request) -> web.Response:
378380

379381
return payload.response
380382

383+
def _find_redirect(self, request: web.Request) -> str:
384+
stripped = self._domain.removeprefix(f"{self._proto}://")
385+
local = f"{self._proto}://{self._host}"
386+
387+
if request.host.startswith((self._domain, stripped)):
388+
redirect = self.redirect_url
389+
elif request.host.startswith((self._host, local)):
390+
redirect = f"{local}/oauth/callback"
391+
else:
392+
redirect = f"{request.scheme}://{request.host}/oauth/callback"
393+
394+
return redirect
395+
381396
async def oauth_redirect(self, request: web.Request) -> web.Response:
397+
redirect = self._find_redirect(request)
398+
382399
scopes: str | None = request.query.get("scopes", request.query.get("scope", None))
383400
force_verify: bool = request.query.get("force_verify", "false").lower() == "true"
384401

@@ -397,7 +414,7 @@ async def oauth_redirect(self, request: web.Request) -> web.Response:
397414
try:
398415
payload: AuthorizationURLPayload = self.client._http.get_authorization_url(
399416
scopes=scopes_,
400-
redirect_uri=self.redirect_url,
417+
redirect_uri=redirect,
401418
force_verify=force_verify,
402419
)
403420
except Exception as e:

twitchio/web/starlette_adapter.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def __init__(
162162
if eventsub_secret and not 10 <= len(eventsub_secret) <= 100:
163163
raise ValueError("Eventsub Secret must be between 10 and 100 characters long.")
164164

165-
self._domain: str | None = None
165+
self._domain: str
166166
self._proto = "https" if (ssl_keyfile or domain) else "http"
167167

168168
if domain:
@@ -358,10 +358,12 @@ async def fetch_token(self, request: Request) -> FetchTokenPayload:
358358
if "code" not in request.query_params:
359359
return FetchTokenPayload(400, response=Response(status_code=400, content="No 'code' parameter provided."))
360360

361+
redirect = self._find_redirect(request)
362+
361363
try:
362364
resp: UserTokenPayload = await self.client._http.user_access_token(
363365
request.query_params["code"],
364-
redirect_uri=self.redirect_url,
366+
redirect_uri=redirect,
365367
)
366368
except HTTPException as e:
367369
logger.error("Exception raised while fetching Token in <%s>: %s", self.__class__.__qualname__, e)
@@ -380,6 +382,25 @@ async def fetch_token(self, request: Request) -> FetchTokenPayload:
380382
payload=resp,
381383
)
382384

385+
def _find_redirect(self, request: Request) -> str:
386+
stripped = self._domain.removeprefix(f"{self._proto}://")
387+
local = f"{self._proto}://{self._host}"
388+
389+
host = request.url.hostname
390+
scheme = request.url.scheme
391+
392+
if not host:
393+
return self.redirect_url
394+
395+
if host.startswith((self._domain, stripped)):
396+
redirect = self.redirect_url
397+
elif host.startswith((self._host, local)):
398+
redirect = f"{local}:{self._port}/oauth/callback"
399+
else:
400+
redirect = f"{scheme}://{host}/oauth/callback"
401+
402+
return redirect
403+
383404
async def oauth_callback(self, request: Request) -> Response:
384405
"""Default route callback for the OAuth Authentication redirect URL.
385406
@@ -423,6 +444,7 @@ async def oauth_callback(self, request: Request) -> Response:
423444
async def oauth_redirect(self, request: Request) -> Response:
424445
scopes: str | None = request.query_params.get("scopes", None)
425446
force_verify: bool = request.query_params.get("force_verify", "false").lower() == "true"
447+
redirect = self._find_redirect(request)
426448

427449
if not scopes:
428450
scopes = str(self.client._http.scopes) if self.client._http.scopes else None
@@ -439,7 +461,7 @@ async def oauth_redirect(self, request: Request) -> Response:
439461
try:
440462
payload: AuthorizationURLPayload = self.client._http.get_authorization_url(
441463
scopes=scopes_,
442-
redirect_uri=self.redirect_url,
464+
redirect_uri=redirect,
443465
force_verify=force_verify,
444466
)
445467
except Exception as e:

0 commit comments

Comments
 (0)