Skip to content

Commit 39122aa

Browse files
authored
Merge pull request #7 from Patch-Code-Prosperity/streaming
Streaming
2 parents 66265f5 + 1c8b045 commit 39122aa

File tree

7 files changed

+222
-13
lines changed

7 files changed

+222
-13
lines changed

api_client.py

Lines changed: 18 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,8 @@ def post_token_request(self, data):
6767
'Content-Type': 'application/x-www-form-urlencoded'
6868
}
6969
response = self.session.post(f"{self.config.API_BASE_URL}/v1/oauth/token", headers=headers, data=data)
70-
if response.ok:
70+
if response.status_code == 200:
7171
self.save_token(response.json())
72-
self.load_token()
7372
self.logger.info("Tokens successfully updated.")
7473
return True
7574
else:
@@ -86,7 +85,7 @@ def refresh_access_token(self):
8685
if not self.post_token_request(data):
8786
self.logger.error("Failed to refresh access token.")
8887
return False
89-
88+
self.token_info = self.load_token()
9089
return self.validate_token()
9190

9291
def save_token(self, token_data):
@@ -106,11 +105,17 @@ def load_token(self):
106105
self.logger.warning(f"Loading token failed: {e}")
107106
return None
108107

109-
def validate_token(self):
108+
def validate_token(self, force=False):
110109
""" Validate the current token's validity. """
110+
print(self.token_info['expires_at'])
111+
print(datetime.now())
112+
print(datetime.fromisoformat(self.token_info['expires_at']))
113+
print(datetime.now() < datetime.fromisoformat(self.token_info['expires_at']))
111114
if self.token_info and datetime.now() < datetime.fromisoformat(self.token_info['expires_at']):
115+
print(f"Token expires in {datetime.fromisoformat(self.token_info['expires_at']) - datetime.now()} seconds")
112116
return True
113-
else:
117+
elif force:
118+
print("Token expired or invalid.")
114119
# get AAPL to validate token
115120
params = {'symbol': 'AAPL'}
116121
response = self.make_request(endpoint=f"{self.config.MARKET_DATA_BASE_URL}/chains", params=params, validating=True)
@@ -146,3 +151,11 @@ def make_request(self, endpoint, method="GET", **kwargs):
146151
response = self.session.request(method, url, headers=headers, **kwargs)
147152
response.raise_for_status()
148153
return response.json()
154+
155+
def get_user_preferences(self):
156+
"""Retrieve user preferences."""
157+
try:
158+
return self.make_request(f'{self.config.TRADER_BASE_URL}/userPreference')
159+
except Exception as e:
160+
self.logger.error(f"Failed to get user preferences: {e}")
161+
return None

color_print.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,11 @@ def print(message_type, message, end="\n"):
1515
def input(message):
1616
return input(f"{ColorPrint.COLORS['input']}{message}")
1717

18+
@staticmethod
19+
def user_input(message):
20+
return input(f"{ColorPrint.COLORS['user']}{message}")
21+
22+
1823

1924
if __name__ == '__main__':
2025
ColorPrint.print('info', 'This is an informational message')

config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ class APIConfig:
1010
ACCOUNTS_BASE_URL = f"{TRADER_BASE_URL}/accounts"
1111
MARKET_DATA_BASE_URL = f"{API_BASE_URL}/marketdata/v1"
1212
ORDERS_BASE_URL = ACCOUNTS_BASE_URL
13+
STREAMER_INFO_URL = f"{API_BASE_URL}/streamer-info"
1314
REQUEST_TIMEOUT = 30 # Timeout for API requests in seconds
1415
RETRY_STRATEGY = {
1516
'total': 3, # Total number of retries to allow

main.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,39 @@
1+
import asyncio
12
from datetime import datetime, timedelta
23
from api_client import APIClient
34
from accounts import Accounts
45
from market_data import Quotes, Options, PriceHistory, Movers, MarketHours, Instruments
56
from orders import Orders
7+
from stream_client import StreamClient
8+
from asyncio import get_event_loop
9+
import stream_utilities
610

711

12+
async def main_stream():
13+
client = APIClient() # Initialize the API client
14+
stream_client = StreamClient(client)
15+
await stream_client.start() # Start and connect
16+
17+
while stream_client.active:
18+
# Construct and send a subscription request
19+
request = stream_utilities.basic_request(
20+
"LEVELONE_EQUITIES",
21+
request_id=stream_client.request_id,
22+
command="SUBS",
23+
customer_id=stream_client.streamer_info.get("schwabClientCustomerId"),
24+
correl_id=stream_client.streamer_info.get("schwabClientCorrelId"),
25+
parameters={
26+
"keys": "TSLA,AMZN,AAPL,NFLX,BABA",
27+
"fields": "0,1,2,3,4,5,8,9,12,13,15,24,28,29,30,31,48"
28+
}
29+
)
30+
await stream_client.send(request)
31+
message = await stream_client.receive()
32+
print(f"Received: {message}")
33+
await asyncio.sleep(1) # Delay between messages
34+
35+
stream_client.stop()
36+
837
def main():
938
client = APIClient() # Initialize the API client
1039
accounts_api = Accounts(client)
@@ -86,6 +115,8 @@ def main():
86115

87116

88117
if __name__ == '__main__':
89-
print(
90-
"Welcome to the unofficial Schwab API interface!\nGitHub: https://github.com/Patch-Code-Prosperity/Pythonic-Schwab-API")
91-
main()
118+
print("Welcome to the unofficial Schwab API interface!\n"
119+
"GitHub: https://github.com/Patch-Code-Prosperity/Pythonic-Schwab-API")
120+
loop = get_event_loop()
121+
loop.run_until_complete(main_stream())
122+
# main()

multi_terminal.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, title="Terminal", height=20, width=200, font=("Courier New",
1515
self.textColor = textColor
1616
self.allowClosing = allowClosing
1717
self.ignoreClosedPrints = ignoreClosedPrints
18-
self.isOpen = False
18+
self.is_open = False
1919
self.start()
2020

2121
def run(self):
@@ -25,16 +25,16 @@ def run(self):
2525
self.text_box = tk.Text(self.root, height=self.height, width=self.width, font=self.font,
2626
bg=self.backgroundColor, fg=self.textColor, state='disabled')
2727
self.text_box.pack(side="left", fill="both", expand=True)
28-
self.isOpen = True
28+
self.is_open = True
2929
self.root.mainloop()
3030

3131
def close(self):
32-
if self.isOpen:
33-
self.isOpen = False
32+
if self.is_open:
33+
self.is_open = False
3434
self.root.destroy()
3535

3636
def print(self, text, end="\n"):
37-
if not self.isOpen:
37+
if not self.is_open:
3838
if not self.ignoreClosedPrints:
3939
raise Exception(f"Terminal '{self.title}' is closed.")
4040
return

stream_client.py

Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
import json
2+
import asyncio
3+
import threading
4+
import websockets
5+
from datetime import datetime, time
6+
from multi_terminal import MultiTerminal
7+
from api_client import APIClient
8+
from stream_utilities import basic_request # Importing utility functions
9+
from color_print import ColorPrint
10+
11+
12+
class StreamClient:
13+
def __init__(self, client: APIClient):
14+
self.client = client
15+
self.websocket = None
16+
self.streamer_info = None
17+
self.start_timestamp = None
18+
self.terminal = MultiTerminal(title="Stream Output")
19+
self.color_print = ColorPrint()
20+
self.active = False
21+
self.login_successful = False
22+
self.request_id = -1
23+
24+
async def start(self):
25+
response = self.client.get_user_preferences()
26+
if 'error' in response: # Assuming error handling is done inside get_user_preferences
27+
self.color_print.print("error", f"Failed to get streamer info: {response['error']}")
28+
exit(1)
29+
self.streamer_info = response['streamerInfo'][0]
30+
login = self._construct_login_message()
31+
await self.connect()
32+
await self.send(login)
33+
34+
async def connect(self):
35+
try:
36+
self.websocket = await websockets.connect(self.streamer_info.get('streamerSocketUrl'))
37+
self.active = True
38+
self.color_print.print("info", "Connection established.")
39+
except Exception as e:
40+
self.color_print.print("error", f"Failed to connect: {e}")
41+
42+
async def send(self, message):
43+
if not self.active:
44+
await self.connect()
45+
try:
46+
await self.websocket.send(json.dumps(message))
47+
self.color_print.print("info", f"Message sent: {json.dumps(message)}")
48+
response = await self.websocket.recv()
49+
await self.handle_response(response)
50+
except Exception as e:
51+
self.color_print.print("error", f"Failed to send message: {e}")
52+
53+
async def handle_response(self, message):
54+
message = json.loads(message)
55+
self.color_print.print("info", f"Received: {message}")
56+
if "Login" in message.get('command', '') and message.get('content', {}).get('code') == 0:
57+
self.login_successful = True
58+
self.color_print.print("info", "Login successful.")
59+
60+
async def receive(self):
61+
try:
62+
return await self.websocket.recv()
63+
except Exception as e:
64+
self.color_print.print("error", f"Error receiving message: {e}")
65+
return None
66+
67+
def _construct_login_message(self):
68+
# Increment request ID for each new request
69+
self.request_id += 1
70+
71+
# Prepare the parameters dictionary specifically for the parameters that need to be nested under 'parameters'
72+
parameters = {
73+
"Authorization": self.client.token_info.get("access_token"),
74+
"SchwabClientChannel": self.streamer_info.get("schwabClientChannel"),
75+
"SchwabClientFunctionId": self.streamer_info.get("schwabClientFunctionId")
76+
}
77+
78+
# Call the basic_request function with customer ID and correlation ID at the top level of the request
79+
return basic_request(
80+
service="ADMIN",
81+
request_id=self.request_id,
82+
command="LOGIN",
83+
customer_id=self.streamer_info.get("schwabClientCustomerId"),
84+
correl_id=self.streamer_info.get("schwabClientCorrelId"),
85+
parameters=parameters
86+
)
87+
88+
async def _connect_and_stream(self, login):
89+
try:
90+
async with websockets.connect(self.streamer_info.get('streamerSocketUrl')) as websocket:
91+
self.websocket = websocket
92+
await websocket.send(json.dumps(login))
93+
while True:
94+
message = await websocket.recv()
95+
await self.handle_message(json.loads(message))
96+
except websockets.exceptions.ConnectionClosedOK:
97+
self.color_print.print("info", "Stream has closed.")
98+
except Exception as e:
99+
self.color_print.print("error", f"{e}")
100+
self._handle_stream_error(e)
101+
102+
async def handle_message(self, message):
103+
if "response" in message and any(
104+
resp.get("code") == "0" for resp in message["response"]): # Check if login is successful
105+
self.color_print.print("info", "Logged in successfully, sending subscription requests...")
106+
else:
107+
self.color_print.print("info", f"Received: {message}")
108+
109+
async def reconnect(self):
110+
self.terminal.print("[INFO]: Attempting to reconnect...")
111+
try:
112+
await asyncio.sleep(10) # Wait before attempting to reconnect
113+
login = self._construct_login_message() # Reconstruct login info
114+
await self._connect_and_stream(login) # Attempt to reconnect
115+
return True
116+
except Exception as e:
117+
self.terminal.print(f"Reconnect failed: {e}")
118+
return False
119+
120+
def _handle_stream_error(self, error):
121+
self.active = False
122+
if isinstance(error, RuntimeError) and str(error) == "Streaming window has been closed":
123+
self.color_print.print("warning", "Streaming window has been closed.")
124+
else:
125+
if (datetime.now() - self.start_timestamp).seconds < 70:
126+
self.color_print.print("error", "Stream not alive for more than 1 minute, exiting...")
127+
else:
128+
self.terminal.print("[WARNING]: Connection lost to server, reconnecting...")
129+
130+
def stop(self):
131+
if self.active:
132+
self.active = False
133+
asyncio.create_task(self.websocket.close())
134+
self.color_print.print("info", "Connection closed.")

stream_utilities.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
def basic_request(service, request_id, command, customer_id, correl_id, parameters=None):
2+
"""Constructs a basic request dictionary for streaming commands.
3+
4+
Args:
5+
service (str): The service name, e.g., 'ADMIN'.
6+
request_id (int): The identifier for this request.
7+
command (str): The command to be executed, e.g., 'LOGIN'.
8+
customer_id (str): The Schwab client customer ID.
9+
correl_id (str): The Schwab client correlation ID.
10+
parameters (dict, optional): Additional parameters for the command.
11+
12+
Returns:
13+
dict: The request dictionary.
14+
"""
15+
request = {
16+
"service": service.upper(),
17+
"requestid": str(request_id),
18+
"command": command.upper(),
19+
"SchwabClientCustomerId": customer_id,
20+
"SchwabClientCorrelId": correl_id
21+
}
22+
# Include parameters if provided
23+
if parameters:
24+
request["parameters"] = parameters
25+
return request

0 commit comments

Comments
 (0)