Skip to content

Commit f5c57e0

Browse files
authored
Break up the login flow and expose the authorization URL to the library user (#174)
* Break up the login flow and expose the authorization URL to the library user * Pass oauth state properly * Test async client creation
1 parent ae16acd commit f5c57e0

File tree

2 files changed

+182
-64
lines changed

2 files changed

+182
-64
lines changed

schwab/auth.py

Lines changed: 88 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from authlib.integrations.httpx_client import AsyncOAuth2Client, OAuth2Client
22
from prompt_toolkit import prompt
33

4+
import collections
45
import contextlib
56
import httpx
67
import json
@@ -27,7 +28,7 @@ def get_logger():
2728
return logging.getLogger(__name__)
2829

2930

30-
def __update_token(token_path):
31+
def __make_update_token_func(token_path):
3132
def update_token(t, *args, **kwargs):
3233
get_logger().info('Updating token to file %s', token_path)
3334

@@ -119,51 +120,6 @@ def wrap_token_in_metadata(self, token):
119120
}
120121

121122

122-
def __fetch_and_register_token_from_redirect(
123-
oauth, redirected_url, api_key, app_secret, token_path,
124-
token_write_func, asyncio, enforce_enums=True):
125-
token = oauth.fetch_token(
126-
TOKEN_ENDPOINT,
127-
authorization_response=redirected_url,
128-
client_id=api_key, auth=(api_key, app_secret))
129-
130-
# Don't emit token details in debug logs
131-
register_redactions(token)
132-
133-
# Set up token writing and perform the initial token write
134-
update_token = (
135-
__update_token(token_path) if token_write_func is None
136-
else token_write_func)
137-
metadata_manager = TokenMetadata(token, int(time.time()), update_token)
138-
update_token = metadata_manager.wrapped_token_write_func()
139-
update_token(token)
140-
141-
# The synchronous and asynchronous versions of the OAuth2Client are similar
142-
# enough that can mostly be used interchangeably. The one currently known
143-
# exception is the token update function: the synchronous version expects a
144-
# synchronous one, the asynchronous requires an async one. The
145-
# oauth_client_update_token variable will contain the appropriate one.
146-
if asyncio:
147-
async def oauth_client_update_token(t, *args, **kwargs):
148-
update_token(t, *args, **kwargs) # pragma: no cover
149-
session_class = AsyncOAuth2Client
150-
client_class = AsyncClient
151-
else:
152-
oauth_client_update_token = update_token
153-
session_class = OAuth2Client
154-
client_class = Client
155-
156-
# Return a new session configured to refresh credentials
157-
return client_class(
158-
api_key,
159-
session_class(api_key,
160-
client_secret=app_secret,
161-
token=token,
162-
update_token=oauth_client_update_token,
163-
leeway=300),
164-
token_metadata=metadata_manager, enforce_enums=enforce_enums)
165-
166-
167123
################################################################################
168124
# client_from_login_flow
169125

@@ -351,9 +307,7 @@ def callback_server():
351307
time.sleep(0.1)
352308

353309
# Open the browser
354-
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
355-
authorization_url, state = oauth.create_authorization_url(
356-
'https://api.schwabapi.com/v1/oauth/authorize')
310+
auth_context = get_auth_context(api_key, callback_url)
357311

358312
print()
359313
print('***********************************************************************')
@@ -363,7 +317,7 @@ def callback_server():
363317
print('browser, captures the resulting OAuth callback, and creates a token')
364318
print('using the result. The authorization URL is:')
365319
print()
366-
print('>>', authorization_url)
320+
print('>>', auth_context.authorization_url)
367321
print()
368322
print('IMPORTANT: Your browser will give you a security warning about an')
369323
print('invalid certificate prior to issuing the redirect. This is because')
@@ -388,7 +342,7 @@ def callback_server():
388342
'this method with interactive=False to skip this input.')
389343

390344
controller = webbrowser.get(requested_browser)
391-
controller.open(authorization_url)
345+
controller.open(auth_context.authorization_url)
392346

393347
# Wait for a response
394348
now = __TIME_TIME()
@@ -420,9 +374,13 @@ def callback_server():
420374
'can set a longer timeout by passing a value of ' +
421375
'callback_timeout to client_from_login_flow.')
422376

423-
return __fetch_and_register_token_from_redirect(
424-
oauth, received_url, api_key, app_secret, token_path,
425-
token_write_func, asyncio, enforce_enums=enforce_enums)
377+
token_write_func = (
378+
__make_update_token_func(token_path) if token_write_func is None
379+
else token_write_func)
380+
381+
return client_from_received_url(
382+
api_key, app_secret, auth_context, received_url,
383+
token_write_func, asyncio, enforce_enums)
426384

427385

428386
################################################################################
@@ -455,8 +413,8 @@ def client_from_token_file(token_path, api_key, app_secret, asyncio=False,
455413
load = __token_loader(token_path)
456414

457415
return client_from_access_functions(
458-
api_key, app_secret, load, __update_token(token_path), asyncio=asyncio,
459-
enforce_enums=enforce_enums)
416+
api_key, app_secret, load, __make_update_token_func(token_path),
417+
asyncio=asyncio, enforce_enums=enforce_enums)
460418

461419

462420
################################################################################
@@ -494,9 +452,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
494452
get_logger().info('Creating new token with callback URL \'%s\' ' +
495453
'and token path \'%s\'', callback_url, token_path)
496454

497-
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
498-
authorization_url, state = oauth.create_authorization_url(
499-
'https://api.schwabapi.com/v1/oauth/authorize')
455+
auth_context = get_auth_context(api_key, callback_url)
500456

501457
print('\n**************************************************************\n')
502458
print('This is the manual login and token creation flow for schwab-py.')
@@ -505,7 +461,7 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
505461
print(' 1. Open the following link by copy-pasting it into the browser')
506462
print(' of your choice:')
507463
print()
508-
print(' ' + authorization_url)
464+
print(' ' + auth_context.authorization_url)
509465
print()
510466
print(' 2. Log in with your account credentials. You may be asked to')
511467
print(' perform two-factor authentication using text messaging or')
@@ -529,11 +485,15 @@ def client_from_manual_flow(api_key, app_secret, callback_url, token_path,
529485
'and update your callback URL to begin with \'https\' ' +
530486
'to stop seeing this message.').format(callback_url))
531487

532-
redirected_url = prompt('Redirect URL> ').strip()
488+
received_url = prompt('Redirect URL> ').strip()
489+
490+
token_write_func = (
491+
__make_update_token_func(token_path) if token_write_func is None
492+
else token_write_func)
533493

534-
return __fetch_and_register_token_from_redirect(
535-
oauth, redirected_url, api_key, app_secret, token_path, token_write_func,
536-
asyncio, enforce_enums=enforce_enums)
494+
return client_from_received_url(
495+
api_key, app_secret, auth_context, received_url, token_write_func,
496+
asyncio, enforce_enums)
537497

538498

539499
################################################################################
@@ -611,6 +571,70 @@ async def oauth_client_update_token(t, *args, **kwargs):
611571
enforce_enums=enforce_enums)
612572

613573

574+
################################################################################
575+
# Tools for incorporating token generation into webapp workflows
576+
577+
578+
AuthContext = collections.namedtuple(
579+
'AuthContext', ['callback_url', 'authorization_url', 'state'])
580+
581+
def get_auth_context(api_key, callback_url, state=None):
582+
oauth = OAuth2Client(api_key, redirect_uri=callback_url)
583+
authorization_url, state = oauth.create_authorization_url(
584+
'https://api.schwabapi.com/v1/oauth/authorize',
585+
state=state)
586+
587+
return AuthContext(callback_url, authorization_url, state)
588+
589+
590+
def client_from_received_url(
591+
api_key, app_secret, auth_context, received_url, token_write_func,
592+
asyncio=False, enforce_enums=True):
593+
# XXX: The AuthContext must be serializable, which means the original
594+
# OAuth2Client created in get_auth_context cannot be passed around.
595+
# Instead, we reconstruct it here.
596+
oauth = OAuth2Client(api_key, redirect_uri=auth_context.callback_url)
597+
598+
token = oauth.fetch_token(
599+
TOKEN_ENDPOINT,
600+
authorization_response=received_url,
601+
client_id=api_key, auth=(api_key, app_secret),
602+
state=auth_context.state)
603+
604+
# Don't emit token details in debug logs
605+
register_redactions(token)
606+
607+
# Set up token writing and perform the initial token write
608+
metadata_manager = TokenMetadata(token, int(time.time()), token_write_func)
609+
token_write_func = metadata_manager.wrapped_token_write_func()
610+
token_write_func(token)
611+
612+
# The synchronous and asynchronous versions of the OAuth2Client are similar
613+
# enough that can mostly be used interchangeably. The one currently known
614+
# exception is the token update function: the synchronous version expects a
615+
# synchronous one, the asynchronous requires an async one. The
616+
# oauth_client_update_token variable will contain the appropriate one.
617+
if asyncio:
618+
async def oauth_client_update_token(t, *args, **kwargs):
619+
token_write_func(t, *args, **kwargs) # pragma: no cover
620+
session_class = AsyncOAuth2Client
621+
client_class = AsyncClient
622+
else:
623+
oauth_client_update_token = token_write_func
624+
session_class = OAuth2Client
625+
client_class = Client
626+
627+
# Return a new session configured to refresh credentials
628+
return client_class(
629+
api_key,
630+
session_class(api_key,
631+
client_secret=app_secret,
632+
token=token,
633+
update_token=oauth_client_update_token,
634+
leeway=300),
635+
token_metadata=metadata_manager, enforce_enums=enforce_enums)
636+
637+
614638
################################################################################
615639
# easy_client
616640

tests/auth_test.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,100 @@ def token_write_func(token):
536536
API_KEY, _, token_metadata=_, enforce_enums=True)
537537

538538

539+
# Note the client_from_received_url is called internally by the other client
540+
# generation functions, so testing here is kept light
541+
class ClientFromReceivedUrl(unittest.TestCase):
542+
543+
def setUp(self):
544+
self.tmp_dir = tempfile.TemporaryDirectory()
545+
self.token_path = os.path.join(self.tmp_dir.name, 'token.json')
546+
self.raw_token = {'token': 'yes'}
547+
548+
@no_duplicates
549+
@patch('schwab.auth.Client')
550+
@patch('schwab.auth.AsyncClient')
551+
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
552+
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
553+
@patch('time.time', MagicMock(return_value=MOCK_NOW))
554+
def test_success_sync(
555+
self, async_session, sync_session, async_client, client):
556+
AUTH_URL = 'https://auth.url.com'
557+
558+
sync_session.return_value = sync_session
559+
sync_session.create_authorization_url.return_value = \
560+
AUTH_URL, 'oauth state'
561+
sync_session.fetch_token.return_value = self.raw_token
562+
563+
auth_context = auth.get_auth_context(API_KEY, CALLBACK_URL)
564+
self.assertEqual(AUTH_URL, auth_context.authorization_url)
565+
self.assertEqual('oauth state', auth_context.state)
566+
567+
client.return_value = 'returned client'
568+
token_capture = []
569+
auth.client_from_received_url(
570+
API_KEY, APP_SECRET, auth_context,
571+
'http://redirect.url.com/?data',
572+
lambda token: token_capture.append(token))
573+
574+
client.assert_called_once()
575+
async_client.assert_not_called()
576+
577+
# Verify that the oauth state is correctly passed along
578+
sync_session.fetch_token.assert_called_once_with(
579+
_,
580+
authorization_response=_,
581+
client_id=_,
582+
auth=_,
583+
state='oauth state')
584+
585+
self.assertEqual([{
586+
'creation_timestamp': MOCK_NOW,
587+
'token': self.raw_token
588+
}], token_capture)
589+
590+
591+
@no_duplicates
592+
@patch('schwab.auth.Client')
593+
@patch('schwab.auth.AsyncClient')
594+
@patch('schwab.auth.OAuth2Client', new_callable=MockOAuthClient)
595+
@patch('schwab.auth.AsyncOAuth2Client', new_callable=MockAsyncOAuthClient)
596+
@patch('time.time', MagicMock(return_value=MOCK_NOW))
597+
def test_success_async(
598+
self, async_session, sync_session, async_client, client):
599+
AUTH_URL = 'https://auth.url.com'
600+
601+
sync_session.return_value = sync_session
602+
sync_session.create_authorization_url.return_value = \
603+
AUTH_URL, 'oauth state'
604+
sync_session.fetch_token.return_value = self.raw_token
605+
606+
auth_context = auth.get_auth_context(API_KEY, CALLBACK_URL)
607+
608+
client.return_value = 'returned client'
609+
token_capture = []
610+
auth.client_from_received_url(
611+
API_KEY, APP_SECRET, auth_context,
612+
'http://redirect.url.com/?data',
613+
lambda token: token_capture.append(token),
614+
asyncio=True)
615+
616+
async_client.assert_called_once()
617+
client.assert_not_called()
618+
619+
# Verify that the oauth state is correctly passed along
620+
sync_session.fetch_token.assert_called_once_with(
621+
_,
622+
authorization_response=_,
623+
client_id=_,
624+
auth=_,
625+
state='oauth state')
626+
627+
self.assertEqual([{
628+
'creation_timestamp': MOCK_NOW,
629+
'token': self.raw_token
630+
}], token_capture)
631+
632+
539633
class ClientFromManualFlow(unittest.TestCase):
540634

541635
def setUp(self):

0 commit comments

Comments
 (0)