1212
1313from .rest import ApiException
1414
15+ import select
1516import certifi
17+ import time
1618import collections
17- import websocket
19+ from websocket import WebSocket , ABNF , enableTrace
1820import six
1921import ssl
2022from six .moves .urllib .parse import urlencode
2123from six .moves .urllib .parse import quote_plus
2224
25+ STDIN_CHANNEL = 0
26+ STDOUT_CHANNEL = 1
27+ STDERR_CHANNEL = 2
28+
2329
2430class WSClient :
2531 def __init__ (self , configuration , url , headers ):
26- self .messages = []
27- self .errors = []
28- websocket .enableTrace (False )
29- header = None
32+ """A websocket client with support for channels.
33+
34+ Exec command uses different channels for different streams. for
35+ example, 0 is stdin, 1 is stdout and 2 is stderr. Some other API calls
36+ like port forwarding can forward different pods' streams to different
37+ channels.
38+ """
39+ enableTrace (False )
40+ header = []
41+ self ._connected = False
42+ self ._channels = {}
43+ self ._all = ""
3044
3145 # We just need to pass the Authorization, ignore all the other
3246 # http headers we get from the generated code
33- if 'Authorization' in headers :
34- header = "Authorization: %s" % headers ['Authorization' ]
35-
36- self .ws = websocket .WebSocketApp (url ,
37- on_message = self .on_message ,
38- on_error = self .on_error ,
39- on_close = self .on_close ,
40- header = [header ] if header else None )
41- self .ws .on_open = self .on_open
47+ if headers and 'authorization' in headers :
48+ header .append ("authorization: %s" % headers ['authorization' ])
4249
4350 if url .startswith ('wss://' ) and configuration .verify_ssl :
4451 ssl_opts = {
@@ -52,30 +59,145 @@ def __init__(self, configuration, url, headers):
5259 else :
5360 ssl_opts = {'cert_reqs' : ssl .CERT_NONE }
5461
55- self .ws .run_forever (sslopt = ssl_opts )
56-
57- def on_message (self , ws , message ):
58- if message [0 ] == '\x01 ' :
59- message = message [1 :]
60- if message :
61- if six .PY3 and isinstance (message , six .binary_type ):
62- message = message .decode ('utf-8' )
63- self .messages .append (message )
64-
65- def on_error (self , ws , error ):
66- self .errors .append (error )
67-
68- def on_close (self , ws ):
69- pass
70-
71- def on_open (self , ws ):
72- pass
62+ self .sock = WebSocket (sslopt = ssl_opts , skip_utf8_validation = False )
63+ self .sock .connect (url , header = header )
64+ self ._connected = True
65+
66+ def peek_channel (self , channel , timeout = 0 ):
67+ """Peek a channel and return part of the input,
68+ empty string otherwise."""
69+ self .update (timeout = timeout )
70+ if channel in self ._channels :
71+ return self ._channels [channel ]
72+ return ""
73+
74+ def read_channel (self , channel , timeout = 0 ):
75+ """Read data from a channel."""
76+ if channel not in self ._channels :
77+ ret = self .peek_channel (channel , timeout )
78+ else :
79+ ret = self ._channels [channel ]
80+ if channel in self ._channels :
81+ del self ._channels [channel ]
82+ return ret
83+
84+ def readline_channel (self , channel , timeout = None ):
85+ """Read a line from a channel."""
86+ if timeout is None :
87+ timeout = float ("inf" )
88+ start = time .time ()
89+ while self .is_open () and time .time () - start < timeout :
90+ if channel in self ._channels :
91+ data = self ._channels [channel ]
92+ if "\n " in data :
93+ index = data .find ("\n " )
94+ ret = data [:index ]
95+ data = data [index + 1 :]
96+ if data :
97+ self ._channels [channel ] = data
98+ else :
99+ del self ._channels [channel ]
100+ return ret
101+ self .update (timeout = (timeout - time .time () + start ))
102+
103+ def write_channel (self , channel , data ):
104+ """Write data to a channel."""
105+ self .sock .send (chr (channel ) + data )
106+
107+ def peek_stdout (self , timeout = 0 ):
108+ """Same as peek_channel with channel=1."""
109+ return self .peek_channel (STDOUT_CHANNEL , timeout = timeout )
110+
111+ def read_stdout (self , timeout = None ):
112+ """Same as read_channel with channel=1."""
113+ return self .read_channel (STDOUT_CHANNEL , timeout = timeout )
114+
115+ def readline_stdout (self , timeout = None ):
116+ """Same as readline_channel with channel=1."""
117+ return self .readline_channel (STDOUT_CHANNEL , timeout = timeout )
118+
119+ def peek_stderr (self , timeout = 0 ):
120+ """Same as peek_channel with channel=2."""
121+ return self .peek_channel (STDERR_CHANNEL , timeout = timeout )
122+
123+ def read_stderr (self , timeout = None ):
124+ """Same as read_channel with channel=2."""
125+ return self .read_channel (STDERR_CHANNEL , timeout = timeout )
126+
127+ def readline_stderr (self , timeout = None ):
128+ """Same as readline_channel with channel=2."""
129+ return self .readline_channel (STDERR_CHANNEL , timeout = timeout )
130+
131+ def read_all (self ):
132+ """Read all of the inputs with the same order they recieved. The channel
133+ information would be part of the string. This is useful for
134+ non-interactive call where a set of command passed to the API call and
135+ their result is needed after the call is concluded.
136+
137+ TODO: Maybe we can process this and return a more meaningful map with
138+ channels mapped for each input.
139+ """
140+ out = self ._all
141+ self ._all = ""
142+ self ._channels = {}
143+ return out
144+
145+ def is_open (self ):
146+ """True if the connection is still alive."""
147+ return self ._connected
148+
149+ def write_stdin (self , data ):
150+ """The same as write_channel with channel=0."""
151+ self .write_channel (STDIN_CHANNEL , data )
152+
153+ def update (self , timeout = 0 ):
154+ """Update channel buffers with at most one complete frame of input."""
155+ if not self .is_open ():
156+ return
157+ if not self .sock .connected :
158+ self ._connected = False
159+ return
160+ r , _ , _ = select .select (
161+ (self .sock .sock , ), (), (), timeout )
162+ if r :
163+ op_code , frame = self .sock .recv_data_frame (True )
164+ if op_code == ABNF .OPCODE_CLOSE :
165+ self ._connected = False
166+ return
167+ elif op_code == ABNF .OPCODE_BINARY or op_code == ABNF .OPCODE_TEXT :
168+ data = frame .data
169+ if six .PY3 :
170+ data = data .decode ("utf-8" )
171+ self ._all += data
172+ if len (data ) > 1 :
173+ channel = ord (data [0 ])
174+ data = data [1 :]
175+ if data :
176+ if channel not in self ._channels :
177+ self ._channels [channel ] = data
178+ else :
179+ self ._channels [channel ] += data
180+
181+ def run_forever (self , timeout = None ):
182+ """Wait till connection is closed or timeout reached. Buffer any input
183+ received during this time."""
184+ if timeout :
185+ start = time .time ()
186+ while self .is_open () and time .time () - start < timeout :
187+ self .update (timeout = (timeout - time .time () + start ))
188+ else :
189+ while self .is_open ():
190+ self .update (timeout = None )
73191
74192
75193WSResponse = collections .namedtuple ('WSResponse' , ['data' ])
76194
77195
78- def GET (configuration , url , query_params , _request_timeout , headers ):
196+ def websocket_call (configuration , url , query_params , _request_timeout ,
197+ _preload_content , headers ):
198+ """An internal function to be called in api-client when a websocket
199+ connection is required."""
200+
79201 # switch protocols from http to websocket
80202 url = url .replace ('http://' , 'ws://' )
81203 url = url .replace ('https://' , 'wss://' )
@@ -105,10 +227,11 @@ def GET(configuration, url, query_params, _request_timeout, headers):
105227 else :
106228 url += '&command=' + quote_plus (commands )
107229
108- client = WSClient (configuration , url , headers )
109- if client .errors :
110- raise ApiException (
111- status = 0 ,
112- reason = '\n ' .join ([str (error ) for error in client .errors ])
113- )
114- return WSResponse ('%s' % '' .join (client .messages ))
230+ try :
231+ client = WSClient (configuration , url , headers )
232+ if not _preload_content :
233+ return client
234+ client .run_forever (timeout = _request_timeout )
235+ return WSResponse ('%s' % '' .join (client .read_all ()))
236+ except (Exception , KeyboardInterrupt , SystemExit ) as e :
237+ raise ApiException (status = 0 , reason = str (e ))
0 commit comments