Skip to content

Commit 8938576

Browse files
authored
Merge pull request #2 from Patch-Code-Prosperity/Cfomodz-refactor-1
Cfomodz refactor 1 - This PR Completes the initial overhaul and begins trying to get a stable main branch so we can start working on the fun stuff!
2 parents 97fc62e + 0c562cc commit 8938576

File tree

7 files changed

+347
-42
lines changed

7 files changed

+347
-42
lines changed

accounts.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from config import TRADER_BASE_URL
2+
import logging
3+
4+
class Accounts:
5+
def __init__(self, client):
6+
self.client = client
7+
self.logger = logging.getLogger(__name__)
8+
self.base_url = TRADER_BASE_URL
9+
10+
def get_account_numbers(self):
11+
"""Retrieve account numbers associated with the user's profile."""
12+
try:
13+
return self.client.get(f'{self.base_url}/accountNumbers')
14+
except Exception as e:
15+
self.logger.error(f"Failed to get account numbers: {e}")
16+
return None
17+
18+
def get_all_accounts(self, fields=None):
19+
"""Retrieve detailed information for all linked accounts, optionally filtering the fields."""
20+
params = {'fields': fields} if fields else {}
21+
try:
22+
return self.client.get(f'{self.base_url}', params=params)
23+
except Exception as e:
24+
self.logger.error(f"Failed to get all accounts: {e}")
25+
return None
26+
27+
def get_account(self, account_hash, fields=None):
28+
"""Retrieve detailed information for a specific account using its hash."""
29+
if not account_hash:
30+
self.logger.error("Account hash is required for getting account details")
31+
return None
32+
params = {'fields': fields} if fields else {}
33+
try:
34+
return self.client.get(f'{self.base_url}/{account_hash}', params=params)
35+
except Exception as e:
36+
self.logger.error(f"Failed to get account {account_hash}: {e}")
37+
return None
38+
39+
def get_account_transactions(self, account_hash, start_date, end_date, types=None, symbol=None):
40+
"""Retrieve transactions for a specific account over a specified date range."""
41+
if not (isinstance(start_date, datetime.datetime) and isinstance(end_date, datetime.datetime)):
42+
self.logger.error("Invalid date format. Dates must be datetime objects")
43+
return None
44+
params = {
45+
'startDate': start_date.isoformat(),
46+
'endDate': end_date.isoformat(),
47+
'types': types,
48+
'symbol': symbol
49+
}
50+
try:
51+
return self.client.get(f'{self.base_url}/{account_hash}/transactions', params=params)
52+
except Exception as e:
53+
self.logger.error(f"Failed to get transactions for account {account_hash}: {e}")
54+
return None

api_client.py

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
import requests
2+
import logging
3+
import json
4+
from datetime import datetime, timedelta
5+
from config import APIConfig
6+
7+
class APIClient:
8+
def __init__(self):
9+
self.session = requests.Session()
10+
self.setup_logging()
11+
12+
if not self.validate_credentials():
13+
logging.error("Invalid or missing credentials. Please check your configuration.")
14+
exit(1)
15+
16+
self.token_info = self.load_token() or self.authenticate()
17+
18+
def setup_logging(self):
19+
logging.basicConfig(**APIConfig.LOGGING_CONFIG)
20+
self.logger = logging.getLogger(__name__)
21+
22+
def validate_credentials(self):
23+
return all([APIConfig.APP_KEY, APIConfig.APP_SECRET, APIConfig.CALLBACK_URL])
24+
25+
def authenticate(self):
26+
"""Authenticate with the API and store the new token information."""
27+
data = {
28+
'grant_type': 'client_credentials',
29+
'client_id': APIConfig.APP_KEY,
30+
'client_secret': APIConfig.APP_SECRET
31+
}
32+
response = self.session.post(f"{APIConfig.API_BASE_URL}/v1/oauth/token", data=data)
33+
response.raise_for_status()
34+
token_data = response.json()
35+
self.save_token(token_data)
36+
return token_data
37+
38+
def save_token(self, token_data):
39+
"""Saves the token data securely to a file."""
40+
token_data['expires_at'] = (datetime.now() + timedelta(seconds=token_data['expires_in'])).isoformat()
41+
with open('token_data.json', 'w') as f:
42+
json.dump(token_data, f)
43+
self.logger.info("Token data saved successfully.")
44+
45+
def load_token(self):
46+
"""Loads the token data from a file if it is still valid."""
47+
try:
48+
with open('token_data.json', 'r') as f:
49+
token_data = json.load(f)
50+
if datetime.now() < datetime.fromisoformat(token_data['expires_at']):
51+
self.logger.info("Token loaded successfully from file.")
52+
return token_data
53+
except (FileNotFoundError, KeyError, ValueError) as e:
54+
self.logger.warning(f"Loading token failed: {e}")
55+
return None
56+
57+
def make_request(self, method, endpoint, **kwargs):
58+
"""Makes an HTTP request using the authenticated session."""
59+
url = f"{APIConfig.API_BASE_URL}{endpoint}"
60+
response = self.session.request(method, url, **kwargs)
61+
if response.status_code == 401: # Token expired
62+
self.logger.warning("Token expired. Refreshing token...")
63+
self.token_info = self.authenticate()
64+
response = self.session.request(method, url, **kwargs)
65+
response.raise_for_status()
66+
return response.json()

api_utilities.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import datetime
2+
3+
class ParameterParser:
4+
@staticmethod
5+
def clean_params(params):
6+
"""
7+
Removes None values from a dictionary of parameters.
8+
This ensures that only valid parameters are sent in API requests.
9+
10+
Parameters:
11+
params (dict): A dictionary containing parameter names and values.
12+
13+
Returns:
14+
dict: A dictionary with all None values removed.
15+
"""
16+
return {k: v for k, v in params.items() if v is not None}
17+
18+
class DateTimeConverter:
19+
@staticmethod
20+
def convert_time(dt=None, format_type="8601"):
21+
"""
22+
Converts datetime objects into a string format according to a specified format type.
23+
Supports ISO 8601, epoch time in milliseconds, and custom date formats.
24+
25+
Parameters:
26+
dt (datetime.datetime, optional): The datetime object to convert. Defaults to None.
27+
format_type (str, optional): The type of format to convert the datetime into.
28+
Supported types are "8601" for ISO8601 format, "epoch" for epoch time in milliseconds,
29+
and "YYYY-MM-DD" for custom date format. Defaults to "8601".
30+
31+
Returns:
32+
str or int: The formatted date string or integer (for epoch), or None if dt is None.
33+
"""
34+
if dt is None:
35+
return None
36+
37+
formats = {
38+
"8601": lambda x: x.isoformat()[:-3] + 'Z', # Reduces to milliseconds and appends 'Z' to denote UTC time
39+
"epoch": lambda x: int(x.timestamp() * 1000), # Converts to milliseconds since the Unix epoch
40+
"YYYY-MM-DD": lambda x: x.strftime("%Y-%m-%d") # Formats date as YYYY-MM-DD
41+
}
42+
return formats.get(format_type, lambda x: x)(dt)
43+
44+
def format_list(items):
45+
"""
46+
Converts a list of items into a comma-separated string. If the input is not a list, it returns the input as is.
47+
This is used to format lists for parameters in API requests where multiple values can be passed as comma-separated.
48+
49+
Parameters:
50+
items (list or str): A list of items or a single string item.
51+
52+
Returns:
53+
str: A comma-separated string if items is a list, otherwise the original input string.
54+
"""
55+
if items is None:
56+
return None
57+
return ",".join(items) if isinstance(items, list) else items

config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import os
2+
from dotenv import load_dotenv
3+
4+
load_dotenv()
5+
6+
class APIConfig:
7+
API_BASE_URL = "https://api.schwabapi.com"
8+
TRADER_BASE_URL = f"{API_BASE_URL}/trader/v1"
9+
MARKET_DATA_BASE_URL = f"{API_BASE_URL}/marketdata/v1"
10+
ORDER_BASE_URL = f"{API_BASE_URL}/accounts"
11+
REQUEST_TIMEOUT = 30 # Timeout for API requests in seconds
12+
RETRY_STRATEGY = {
13+
'total': 3, # Total number of retries to allow
14+
'backoff_factor': 1 # Factor by which the delay between retries will increase
15+
}
16+
TOKEN_REFRESH_THRESHOLD_SECONDS = 300 # Time in seconds before token expiration to attempt refresh
17+
DEBUG_MODE = False
18+
LOGGING_CONFIG = {
19+
'level': 'INFO',
20+
'format': '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
21+
}
22+
APP_KEY = os.getenv('APP_KEY')
23+
APP_SECRET = os.getenv('APP_SECRET')
24+
CALLBACK_URL = os.getenv('CALLBACK_URL')

main.py

Lines changed: 41 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
1-
from modules import api, stream
21
from datetime import datetime, timedelta
2+
from api_client import APIClient
3+
from accounts import Accounts
4+
from orders import Orders
5+
from market_data import MarketData
36

47
def main():
8+
client = APIClient() # Initialize the API client
9+
accounts_api = Accounts(client)
10+
orders_api = Orders(client)
11+
market_data_api = MarketData(client)
12+
513
# Get account numbers for linked accounts
6-
print(api.accounts.get_account_numbers().json())
14+
print(accounts_api.get_account_numbers().json())
715

816
# Get positions for linked accounts
9-
print(api.accounts.get_all_accounts().json())
17+
print(accounts_api.get_all_accounts().json())
1018

1119
# Get specific account positions
12-
print(api.accounts.get_account(fields="positions").json())
20+
print(accounts_api.get_account(fields="positions").json())
1321

1422
# Get up to 3000 orders for an account for the past week
15-
print(api.orders.get_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
23+
print(orders_api.get_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
1624

17-
# Place an order (uncomment to test)
25+
# Example to place an order (commented out for safety)
1826
"""
19-
order = {
27+
order_details = {
2028
"orderType": "LIMIT",
2129
"session": "NORMAL",
2230
"duration": "DAY",
@@ -26,64 +34,55 @@ def main():
2634
{"instruction": "BUY", "quantity": 1, "instrument": {"symbol": "INTC", "assetType": "EQUITY"}}
2735
]
2836
}
29-
response = api.orders.place_order(order)
30-
print(f"Place order response: {response}")
31-
order_id = response.headers.get('location', '/').split('/')[-1]
32-
print(f"OrderID: {order_id}")
33-
34-
# Get a specific order
35-
print(api.orders.get_order(order_id).json())
36-
37-
# Cancel specific order
38-
print(api.orders.cancel_order(order_id))
37+
order_response = orders_api.place_order('account_hash', order_details)
38+
print(f"Place order response: {order_response.json()}")
39+
order_id = order_response.headers.get('location', '/').split('/')[-1]
3940
"""
4041

41-
# Replace specific order
42-
# api.orders.replace_order(order_id, order)
42+
# Get a specific order
43+
# print(orders_api.get_order('account_hash', order_id).json())
4344

4445
# Get up to 3000 orders for all accounts for the past week
45-
print(api.orders.get_all_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
46+
print(orders_api.get_orders(3000, datetime.now() - timedelta(days=7), datetime.now()).json())
4647

4748
# Get all transactions for an account
48-
print(api.transactions.get_transactions(datetime.now() - timedelta(days=7), datetime.now(), "TRADE").json())
49+
print(accounts_api.get_account_transactions('account_hash', datetime.now() - timedelta(days=7), datetime.now(), "TRADE").json())
4950

5051
# Get user preferences for an account
51-
print(api.user_preference.get_user_preference().json())
52+
print(accounts_api.get_user_preferences('account_hash').json())
53+
54+
# Market-data-related requests
55+
quotes = market_data_api.Quotes(market_data_api)
56+
options = market_data_api.Options(market_data_api)
57+
price_history = market_data_api.PriceHistory(market_data_api)
58+
movers = market_data_api.Movers(market_data_api)
59+
market_hours = market_data_api.MarketHours(market_data_api)
60+
instruments = market_data_api.Instruments(market_data_api)
5261

5362
# Get a list of quotes
54-
print(api.quotes.get_list(["AAPL", "AMD"]).json())
63+
print(quotes.get_list(["AAPL", "AMD"]).json())
5564

5665
# Get a single quote
57-
print(api.quotes.get_single("INTC").json())
66+
print(quotes.get_single("INTC").json())
5867

5968
# Get an option expiration chain
60-
print(api.options.get_expiration_chain("AAPL").json())
69+
print(options.get_chains("AAPL").json())
6170

6271
# Get movers for an index
63-
print(api.movers.get_movers("$DJI").json())
72+
print(movers.get_movers("$DJI").json())
6473

6574
# Get market hours for symbols
66-
print(api.market_hours.get_by_markets("equity,option").json())
75+
print(market_hours.by_markets("equity,option").json())
6776

6877
# Get market hours for a market
69-
print(api.market_hours.get_by_market("equity").json())
78+
print(market_hours.by_market("equity").json())
7079

7180
# Get instruments for a symbol
72-
print(api.instruments.get_by_symbol("AAPL", "search").json())
81+
print(instruments.by_symbol("AAPL", "search").json())
7382

7483
# Get instruments for a CUSIP
75-
print(api.instruments.get_by_cusip("037833100").json()) # 037833100 = AAPL
76-
77-
# Send a subscription request to the stream (uncomment if you start the stream below)
78-
"""
79-
stream.send(stream.utilities.basic_request("CHART_EQUITY", "SUBS", parameters={"keys": "AMD,INTC", "fields": "0,1,2,3,4,5,6,7,8"}))
80-
# Stop the stream after 30s
81-
stream.stop()
82-
"""
84+
print(instruments.by_cusip("037833100").json()) # 037833100 = AAPL
8385

8486
if __name__ == '__main__':
85-
print("Welcome to the unofficial Schwab API interface!\nGitHub: https://github.com/tylerebowers/Schwab-API-Python")
86-
api.initialize() # checks tokens & loads variables
87-
api.update_tokens_automatically() # starts thread to update tokens automatically
88-
# stream.start_manual() # start the stream manually
89-
main() # call the user code above
87+
print("Welcome to the unofficial Schwab API interface!\nGitHub: https://github.com/Patch-Code-Prosperity/Pythonic-Schwab-API")
88+
main()

0 commit comments

Comments
 (0)