1111
1212class StreamClient :
1313 def __init__ (self , client : APIClient ):
14+ self .client = client
1415 self .websocket = None
1516 self .streamer_info = None
1617 self .start_timestamp = None
1718 self .terminal = MultiTerminal (title = "Stream Output" )
1819 self .color_print = ColorPrint ()
1920 self .active = False
20- self .request_id = 0
21- self .client = client
21+ self .login_successful = False
22+ self .request_id = - 1
2223
2324 async def start (self ):
2425 response = self .client .get_user_preferences ()
25- if not response :
26- self .color_print .print ("error" , f"Failed to get streamer info: { response . text } " )
26+ if 'error' in response : # Assuming error handling is done inside get_user_preferences
27+ self .color_print .print ("error" , f"Failed to get streamer info: { response [ 'error' ] } " )
2728 exit (1 )
2829 self .streamer_info = response ['streamerInfo' ][0 ]
2930 login = self ._construct_login_message ()
30- self .color_print .print ("info" , "Starting stream..." )
31- self .color_print .print ("info" , f"Streamer info: { self .streamer_info } " )
32- self .color_print .print ("info" , f"Login message: { login } " )
33- while True :
34- try :
35- await self ._connect_and_stream (login )
36- except websockets .exceptions .ConnectionClosedOK :
37- self .color_print .print ("info" , "Stream has closed." )
38- break
39- except Exception as e :
40- self .color_print .print ("error" , f"{ e } " )
41- self ._handle_stream_error (e )
42-
31+ await self .connect ()
32+ await self .send (login )
33+
34+ async def connect (self ):
35+ try :
36+ self .websocket = await websockets .connect (self .streamer_info .get ('streamerSocketUrl' ))
37+ self .active = True
38+ self .color_print .print ("info" , "Connection established." )
39+ except Exception as e :
40+ self .color_print .print ("error" , f"Failed to connect: { e } " )
41+
42+ async def send (self , message ):
43+ if not self .active :
44+ await self .connect ()
45+ try :
46+ await self .websocket .send (json .dumps (message ))
47+ self .color_print .print ("info" , f"Message sent: { json .dumps (message )} " )
48+ response = await self .websocket .recv ()
49+ await self .handle_response (response )
50+ except Exception as e :
51+ self .color_print .print ("error" , f"Failed to send message: { e } " )
52+
53+ async def handle_response (self , message ):
54+ message = json .loads (message )
55+ self .color_print .print ("info" , f"Received: { message } " )
56+ if "Login" in message .get ('command' , '' ) and message .get ('content' , {}).get ('code' ) == 0 :
57+ self .login_successful = True
58+ self .color_print .print ("info" , "Login successful." )
59+
60+ async def receive (self ):
61+ try :
62+ return await self .websocket .recv ()
63+ except Exception as e :
64+ self .color_print .print ("error" , f"Error receiving message: { e } " )
65+ return None
66+
4367 def _construct_login_message (self ):
68+ # Increment request ID for each new request
4469 self .request_id += 1
45- return basic_request ("ADMIN" , "LOGIN" , self .request_id , {
70+
71+ # Prepare the parameters dictionary specifically for the parameters that need to be nested under 'parameters'
72+ parameters = {
4673 "Authorization" : self .client .token_info .get ("access_token" ),
4774 "SchwabClientChannel" : self .streamer_info .get ("schwabClientChannel" ),
48- "SchwabClientFunctionId" : self .streamer_info .get ("schwabClientFunctionId" ),
49- "SchwabClientCustomerId" : self .streamer_info .get ("schwabClientCustomerId" ),
50- "SchwabClientCorrelId" : self .streamer_info .get ("schwabClientCorrelId" )
51- })
75+ "SchwabClientFunctionId" : self .streamer_info .get ("schwabClientFunctionId" )
76+ }
77+
78+ # Call the basic_request function with customer ID and correlation ID at the top level of the request
79+ return basic_request (
80+ service = "ADMIN" ,
81+ request_id = self .request_id ,
82+ command = "LOGIN" ,
83+ customer_id = self .streamer_info .get ("schwabClientCustomerId" ),
84+ correl_id = self .streamer_info .get ("schwabClientCorrelId" ),
85+ parameters = parameters
86+ )
5287
5388 async def _connect_and_stream (self , login ):
54- self .start_timestamp = datetime .now ()
55- self .color_print .print ("info" , "Connecting to server..." )
56- self .color_print .print ("info" , f"Start timestamp: { self .start_timestamp } " )
57- self .color_print .print ("info" , f"Streamer socket URL: { self .streamer_info .get ('streamerSocketUrl' )} " )
58- async with websockets .connect (self .streamer_info .get ('streamerSocketUrl' ),
59- ping_interval = None ) as self .websocket :
60- self .terminal .print ("[INFO]: Connecting to server..." )
61- await self .websocket .send (json .dumps (login ))
62- self .terminal .print (f"[Login]: { await self .websocket .recv ()} " )
63- self .active = True
64- while True :
65- received = await self .websocket .recv ()
66- self .terminal .print (received )
89+ try :
90+ async with websockets .connect (self .streamer_info .get ('streamerSocketUrl' )) as websocket :
91+ self .websocket = websocket
92+ await websocket .send (json .dumps (login ))
93+ while True :
94+ message = await websocket .recv ()
95+ await self .handle_message (json .loads (message ))
96+ except websockets .exceptions .ConnectionClosedOK :
97+ self .color_print .print ("info" , "Stream has closed." )
98+ except Exception as e :
99+ self .color_print .print ("error" , f"{ e } " )
100+ self ._handle_stream_error (e )
101+
102+ async def handle_message (self , message ):
103+ if "response" in message and any (
104+ resp .get ("code" ) == "0" for resp in message ["response" ]): # Check if login is successful
105+ self .color_print .print ("info" , "Logged in successfully, sending subscription requests..." )
106+ await self .send_subscription_requests ()
107+ else :
108+ self .color_print .print ("info" , f"Received: { message } " )
109+
110+ async def reconnect (self ):
111+ self .terminal .print ("[INFO]: Attempting to reconnect..." )
112+ try :
113+ await asyncio .sleep (10 ) # Wait before attempting to reconnect
114+ login = self ._construct_login_message () # Reconstruct login info
115+ await self ._connect_and_stream (login ) # Attempt to reconnect
116+ return True
117+ except Exception as e :
118+ self .terminal .print (f"Reconnect failed: { e } " )
119+ return False
67120
68121 def _handle_stream_error (self , error ):
69122 self .active = False
@@ -75,16 +128,8 @@ def _handle_stream_error(self, error):
75128 else :
76129 self .terminal .print ("[WARNING]: Connection lost to server, reconnecting..." )
77130
78- async def send (self , listOfRequests ):
79-
80- if not isinstance (listOfRequests , list ):
81- listOfRequests = [listOfRequests ]
82- if self .active :
83- to_send = json .dumps ({"requests" : listOfRequests })
84- await self .websocket .send (to_send )
85- else :
86- self .color_print .print ("warning" , "Stream is not active, nothing sent." )
87-
88131 def stop (self ):
89- self .send (basic_request ("ADMIN" , "LOGOUT" , self .request_id ))
90- self .active = False
132+ if self .active :
133+ self .active = False
134+ asyncio .create_task (self .websocket .close ())
135+ self .color_print .print ("info" , "Connection closed." )
0 commit comments