Skip to content

Commit 31187e1

Browse files
committed
post overhaul cleanup pt 2
1 parent 12db739 commit 31187e1

File tree

10 files changed

+224
-77
lines changed

10 files changed

+224
-77
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
.idea/
22
__pycache__/
33
.env
4+
token_data.json

accounts.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ class Accounts:
66
def __init__(self, client):
77
self.client = client
88
self.logger = logging.getLogger(__name__)
9-
self.base_url = client.config.TRADER_BASE_URL
9+
self.base_url = client.config.ACCOUNTS_BASE_URL
1010

1111
def get_account_numbers(self):
1212
"""Retrieve account numbers associated with the user's profile."""
1313
try:
14-
response = self.client.get(f'{self.base_url}/accountNumbers')
15-
return response.json()
14+
print(f'{self.base_url}/accountNumbers')
15+
return self.client.make_request(f'{self.base_url}/accountNumbers')
1616
except Exception as e:
1717
self.logger.error(f"Failed to get account numbers: {e}")
1818
return None
@@ -21,7 +21,7 @@ def get_all_accounts(self, fields=None):
2121
"""Retrieve detailed information for all linked accounts, optionally filtering the fields."""
2222
params = {'fields': fields} if fields else {}
2323
try:
24-
return self.client.get(f'{self.base_url}', params=params)
24+
return self.client.make_request(f'{self.base_url}', params=params)
2525
except Exception as e:
2626
self.logger.error(f"Failed to get all accounts: {e}")
2727
return None
@@ -33,7 +33,7 @@ def get_account(self, account_hash, fields=None):
3333
return None
3434
params = {'fields': fields} if fields else {}
3535
try:
36-
return self.client.get(f'{self.base_url}/{account_hash}', params=params)
36+
return self.client.make_request(f'{self.base_url}/{account_hash}', params=params)
3737
except Exception as e:
3838
self.logger.error(f"Failed to get account {account_hash}: {e}")
3939
return None
@@ -50,7 +50,7 @@ def get_account_transactions(self, account_hash, start_date, end_date, types=Non
5050
'symbol': symbol
5151
}
5252
try:
53-
return self.client.get(f'{self.base_url}/{account_hash}/transactions', params=params)
53+
return self.client.make_request(f'{self.base_url}/{account_hash}/transactions', params=params)
5454
except Exception as e:
5555
self.logger.error(f"Failed to get transactions for account {account_hash}: {e}")
5656
return None

api_client.py

Lines changed: 113 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,148 @@
1+
import random
2+
from time import sleep
13
import requests
2-
import logging
4+
import webbrowser
5+
import base64
36
import json
47
from datetime import datetime, timedelta
8+
import logging
59
from config import APIConfig
10+
from color_print import ColorPrint
611

712

813
class APIClient:
914
def __init__(self):
15+
self.account_numbers = None
16+
self.config = APIConfig
1017
self.session = requests.Session()
1118
self.setup_logging()
19+
self.token_info = self.load_token()
1220

13-
if not self.validate_credentials():
14-
logging.error("Invalid or missing credentials. Please check your configuration.")
15-
exit(1)
16-
17-
self.token_info = self.load_token() or self.authenticate()
21+
# Validate and refresh token or reauthorize if necessary
22+
if not self.token_info or not self.ensure_valid_token():
23+
self.manual_authorization_flow()
1824

1925
def setup_logging(self):
2026
logging.basicConfig(**APIConfig.LOGGING_CONFIG)
2127
self.logger = logging.getLogger(__name__)
2228

23-
def validate_credentials(self):
24-
return all([APIConfig.APP_KEY, APIConfig.APP_SECRET, APIConfig.CALLBACK_URL])
29+
def ensure_valid_token(self):
30+
"""Ensure the token is valid, refresh if possible, otherwise prompt for reauthorization."""
31+
if self.token_info:
32+
if self.validate_token():
33+
self.logger.info("Token loaded and valid.")
34+
return True
35+
elif 'refresh_token' in self.token_info:
36+
self.logger.info("Access token expired. Attempting to refresh.")
37+
if self.refresh_access_token():
38+
return True
39+
self.logger.warning("Token invalid and could not be refreshed.")
40+
return False
41+
42+
def manual_authorization_flow(self):
43+
""" Handle the manual steps required to get the authorization code from the user. """
44+
self.logger.info("Starting manual authorization flow.")
45+
auth_url = f"{APIConfig.API_BASE_URL}/v1/oauth/authorize?client_id={APIConfig.APP_KEY}&redirect_uri={APIConfig.CALLBACK_URL}&response_type=code"
46+
webbrowser.open(auth_url)
47+
self.logger.info(f"Please authorize the application by visiting: {auth_url}")
48+
response_url = ColorPrint.input(
49+
"After authorizing, wait for it to load (<1min) and paste the WHOLE url here: ")
50+
authorization_code = f"{response_url[response_url.index('code=') + 5:response_url.index('%40')]}@"
51+
# session = response_url[response_url.index("session=")+8:]
52+
self.exchange_authorization_code_for_tokens(authorization_code)
2553

26-
def authenticate(self):
27-
"""Authenticate with the API and store the new token information."""
54+
def exchange_authorization_code_for_tokens(self, code):
55+
""" Exchange the authorization code for access and refresh tokens. """
2856
data = {
29-
'grant_type': 'client_credentials',
30-
'client_id': APIConfig.APP_KEY,
31-
'client_secret': APIConfig.APP_SECRET
57+
'grant_type': 'authorization_code',
58+
'code': code,
59+
'redirect_uri': self.config.CALLBACK_URL
3260
}
33-
response = self.session.post(f"{APIConfig.API_BASE_URL}/v1/oauth/token", data=data)
34-
response.raise_for_status()
35-
token_data = response.json()
36-
self.save_token(token_data)
37-
return token_data
61+
self.post_token_request(data)
62+
63+
def post_token_request(self, data):
64+
""" Generalized token request handling. """
65+
headers = {
66+
'Authorization': f'Basic {base64.b64encode(f"{self.config.APP_KEY}:{self.config.APP_SECRET}".encode()).decode()}',
67+
'Content-Type': 'application/x-www-form-urlencoded'
68+
}
69+
response = self.session.post(f"{self.config.API_BASE_URL}/v1/oauth/token", headers=headers, data=data)
70+
if response.ok:
71+
self.save_token(response.json())
72+
self.load_token()
73+
self.logger.info("Tokens successfully updated.")
74+
return True
75+
else:
76+
self.logger.error("Failed to obtain tokens.")
77+
response.raise_for_status()
78+
79+
def refresh_access_token(self):
80+
"""Use the refresh token to obtain a new access token and validate it."""
81+
82+
data = {
83+
'grant_type': 'refresh_token',
84+
'refresh_token': self.token_info['refresh_token']
85+
}
86+
if not self.post_token_request(data):
87+
self.logger.error("Failed to refresh access token.")
88+
return False
89+
90+
return self.validate_token()
3891

3992
def save_token(self, token_data):
40-
"""Saves the token data securely to a file."""
93+
""" Save token data securely. """
4194
token_data['expires_at'] = (datetime.now() + timedelta(seconds=token_data['expires_in'])).isoformat()
4295
with open('token_data.json', 'w') as f:
4396
json.dump(token_data, f)
4497
self.logger.info("Token data saved successfully.")
4598

4699
def load_token(self):
47-
"""Loads the token data from a file if it is still valid."""
100+
""" Load token data. """
48101
try:
49102
with open('token_data.json', 'r') as f:
50103
token_data = json.load(f)
51-
if datetime.now() < datetime.fromisoformat(token_data['expires_at']):
52-
self.logger.info("Token loaded successfully from file.")
53-
return token_data
54-
except (FileNotFoundError, KeyError, ValueError) as e:
104+
return token_data
105+
except Exception as e:
55106
self.logger.warning(f"Loading token failed: {e}")
56107
return None
57108

58-
def make_request(self, method, endpoint, **kwargs):
59-
"""Makes an HTTP request using the authenticated session."""
60-
url = f"{APIConfig.API_BASE_URL}{endpoint}"
61-
response = self.session.request(method, url, **kwargs)
62-
if response.status_code == 401: # Token expired
63-
self.logger.warning("Token expired. Refreshing token...")
64-
self.token_info = self.authenticate()
65-
response = self.session.request(method, url, **kwargs)
109+
def validate_token(self):
110+
""" Validate the current token's validity. """
111+
if self.token_info and datetime.now() < datetime.fromisoformat(self.token_info['expires_at']):
112+
return True
113+
else:
114+
# get AAPL to validate token
115+
params = {'symbol': 'AAPL'}
116+
response = self.make_request(endpoint=f"{self.config.MARKET_DATA_BASE_URL}/chains", params=params, validating=True)
117+
print(response)
118+
if response:
119+
self.logger.info("Token validated successfully.")
120+
# self.account_numbers = response.json()
121+
return True
122+
self.logger.warning("Token validation failed.")
123+
return False
124+
125+
def make_request(self, endpoint, method="GET", **kwargs):
126+
sleep(0.5 + random.randint(0, 1000) / 1000)
127+
""" Make authenticated HTTP requests. """
128+
if 'validating' not in kwargs:
129+
if not self.validate_token():
130+
self.logger.info("Token expired or invalid, re-authenticating.")
131+
self.manual_authorization_flow()
132+
kwargs.pop('validating', None)
133+
if self.config.API_BASE_URL not in endpoint:
134+
url = f"{self.config.API_BASE_URL}{endpoint}"
135+
else:
136+
url = endpoint
137+
print(f"Making request to {url} with method {method} and kwargs {kwargs} (validating already popped if present)")
138+
headers = {'Authorization': f"Bearer {self.token_info['access_token']}"}
139+
response = self.session.request(method, url, headers=headers, **kwargs)
140+
print(response.status_code)
141+
print(response.text)
142+
if response.status_code == 401:
143+
self.logger.warning("Token expired during request. Refreshing token...")
144+
self.manual_authorization_flow()
145+
headers = {'Authorization': f"Bearer {self.token_info['access_token']}"}
146+
response = self.session.request(method, url, headers=headers, **kwargs)
66147
response.raise_for_status()
67148
return response.json()

color_print.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
class ColorPrint:
2+
COLORS = {
3+
'info': '\033[92m[INFO]: \033[00m',
4+
'warning': '\033[93m[WARN]: \033[00m',
5+
'error': '\033[91m[ERROR]: \033[00m',
6+
'input': '\033[94m[INPUT]: \033[00m',
7+
'user': '\033[1;31m[USER]: \033[00m'
8+
}
9+
10+
@staticmethod
11+
def print(message_type, message, end="\n"):
12+
print(f"{ColorPrint.COLORS.get(message_type, '[UNKNOWN]: ')}{message}", end=end)
13+
14+
@staticmethod
15+
def input(message):
16+
return input(f"{ColorPrint.COLORS['input']}{message}")
17+
18+
19+
if __name__ == '__main__':
20+
ColorPrint.print('info', 'This is an informational message')

config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@
77
class APIConfig:
88
API_BASE_URL = "https://api.schwabapi.com"
99
TRADER_BASE_URL = f"{API_BASE_URL}/trader/v1"
10+
ACCOUNTS_BASE_URL = f"{TRADER_BASE_URL}/accounts"
1011
MARKET_DATA_BASE_URL = f"{API_BASE_URL}/marketdata/v1"
11-
ORDER_BASE_URL = f"{API_BASE_URL}/accounts"
12+
ORDERS_BASE_URL = ACCOUNTS_BASE_URL
1213
REQUEST_TIMEOUT = 30 # Timeout for API requests in seconds
1314
RETRY_STRATEGY = {
1415
'total': 3, # Total number of retries to allow

main.py

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from datetime import datetime, timedelta
22
from api_client import APIClient
33
from accounts import Accounts
4+
from market_data import Quotes, Options, PriceHistory, Movers, MarketHours, Instruments
45
from orders import Orders
56

67

@@ -10,13 +11,16 @@ def main():
1011
orders_api = Orders(client)
1112

1213
# Get account numbers for linked accounts
13-
print(accounts_api.get_account_numbers())
14+
# print(accounts_api.get_account_numbers()) # working
1415

1516
# Get positions for linked accounts
16-
print(accounts_api.get_all_accounts().json())
17+
# print(accounts_api.get_all_accounts()) # working
18+
19+
sample_account = client.account_numbers[0]
20+
account_hash = sample_account['accountHash']
1721

1822
# Get specific account positions
19-
print(accounts_api.get_account(fields="positions").json())
23+
# print(accounts_api.get_account(fields="positions"))
2024

2125
# Get up to 3000 orders for an account for the past week
2226
print(orders_api.get_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
@@ -42,22 +46,19 @@ def main():
4246
# print(orders_api.get_order('account_hash', order_id).json())
4347

4448
# Get up to 3000 orders for all accounts for the past week
45-
print(orders_api.get_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
49+
print(orders_api.get_orders(account_hash=account_hash, max_results=3000, from_entered_time=datetime.now() - timedelta(days=7), to_entered_time=datetime.now()))
4650

4751
# Get all transactions for an account
4852
print(accounts_api.get_account_transactions('account_hash', datetime.now() - timedelta(days=7), datetime.now(),
4953
"TRADE").json())
5054

51-
# Get user preferences for an account
52-
print(accounts_api.get_user_preferences('account_hash').json())
53-
5455
# Market-data-related requests
55-
quotes = market_data_api.Quotes(market_data_api)
56-
options = market_data_api.Options(market_data_api)
57-
price_history = market_data_api.PriceHistory(market_data_api)
58-
movers = market_data_api.Movers(market_data_api)
59-
market_hours = market_data_api.MarketHours(market_data_api)
60-
instruments = market_data_api.Instruments(market_data_api)
56+
quotes = Quotes(client)
57+
options = Options(client)
58+
price_history = PriceHistory(client)
59+
movers = Movers(client)
60+
market_hours = MarketHours(client)
61+
instruments = Instruments(client)
6162

6263
# Get a list of quotes
6364
print(quotes.get_list(["AAPL", "AMD"]).json())

0 commit comments

Comments
 (0)