1- import os
1+ #
2+ # The MIT License
3+ #
4+ # @copyright Copyright (c) 2017 Intel Corporation
5+ # @copyright Copyright (c) 2021 ApertureData Inc
6+ #
7+ # Permission is hereby granted, free of charge, to any person obtaining a copy
8+ # of this software and associated documentation files (the "Software"),
9+ # to deal in the Software without restriction,
10+ # including without limitation the rights to use, copy, modify,
11+ # merge, publish, distribute, sublicense, and/or sell
12+ # copies of the Software, and to permit persons to whom the Software is
13+ # furnished to do so, subject to the following conditions:
14+ #
15+ # The above copyright notice and this permission notice shall be included in
16+ # all copies or substantial portions of the Software.
17+ #
18+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
19+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
20+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
21+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
22+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,
23+ # ARISING FROM,
24+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25+ # THE SOFTWARE.
26+ #
27+
228import sys
29+ import os
30+ import socket
31+ import struct
332import time
33+ import json
34+ import ssl
435
5- import vdms
36+ # VDMS Protobuf import (autogenerated)
37+ from . import queryMessage_pb2
638
739BACKOFF_TIME = 0.2 # seconds
840BACKOFF_MULTIPLIER = 1.2
941
42+ PROTOCOL_VERSION = 1
43+
1044class Connector (object ):
1145
12- def __init__ (self , ip = "localhost" , port = 55555 ):
46+ def __init__ (self , host = "localhost" , port = 55555 ,
47+ user = "" , password = "" , token = "" ,
48+ session = None ,
49+ use_ssl = True ):
1350
14- self .ip = ip
15- self .port = port
51+ self .use_ssl = use_ssl
1652
17- self .connector = vdms . vdms ()
18- self .connector . connect ( ip , port )
53+ self .host = host
54+ self .port = port
1955
56+ self .connected = False
57+ self .last_response = ''
2058 self .last_query_time = 0
2159
60+ self .session_token = ""
61+
62+ self ._connect ()
63+
64+ if session :
65+ self .session = session
66+ else :
67+ try :
68+ self ._authenticate (user , password , token )
69+ except Exception as e :
70+ raise Exception ("Authentication failed:" , str (e ))
71+
72+ self .session_token = self .session ["session_token" ]
73+
74+ def _send_msg (self , data ):
75+
76+ sent_len = struct .pack ('@I' , len (data )) # send size first
77+ self .conn .send (sent_len )
78+ self .conn .send (data )
79+
80+ def _recv_msg (self ):
81+
82+ recv_len = self .conn .recv (4 ) # get message size
83+
84+ recv_len = struct .unpack ('@I' , recv_len )[0 ]
85+ response = b''
86+ while len (response ) < recv_len :
87+ packet = self .conn .recv (recv_len - len (response ))
88+ if not packet :
89+ print ("Error receiving" )
90+ response += packet
91+
92+ return response
93+
94+ def _authenticate (self , user , password = "" , token = "" ):
95+
96+ query = [{
97+ "Authenticate" : {
98+ "username" : user ,
99+ }
100+ }]
101+
102+ if password :
103+ query [0 ]["Authenticate" ]["password" ] = password
104+ elif token :
105+ query [0 ]["Authenticate" ]["token" ] = token
106+ else :
107+ raise Exception ("Either password or token must be specified for authentication" )
108+
109+ response , _ = self .query (query )
110+
111+ self .session = response [0 ]["Authenticate" ]
112+ if self .session ["status" ] != 0 :
113+ raise Exception (self .session ["info" ])
114+
115+ self .session_token = self .session ["session_token" ]
116+
117+ def _connect (self ):
118+
119+ self .conn = socket .socket (socket .AF_INET , socket .SOCK_STREAM )
120+ self .conn .setsockopt (socket .SOL_TCP , socket .TCP_NODELAY , 1 )
121+
122+ # TCP_QUICKACK only supported in Linux 2.4.4+.
123+ # We use startswith for checking the platform following Python's
124+ # documentation:
125+ # https://docs.python.org/dev/library/sys.html#sys.platform
126+ if sys .platform .startswith ('linux' ):
127+ self .conn .setsockopt (socket .SOL_TCP , socket .TCP_QUICKACK , 1 )
128+
129+ self .conn .connect ((self .host , self .port ))
130+
131+ # Handshake with server to negotiate protocol
132+
133+ protocol = 2 if self .use_ssl else 1
134+
135+ hello_msg = struct .pack ('@II' , PROTOCOL_VERSION , protocol )
136+
137+ # Send desire protocol
138+ self ._send_msg (hello_msg )
139+
140+ # Receive response from server
141+ response = self ._recv_msg ()
142+
143+ version , server_protocol = struct .unpack ('@II' , response )
144+
145+ if version != PROTOCOL_VERSION :
146+ print ("WARNING: Protocol version differ from server" )
147+
148+ if server_protocol != protocol :
149+ self .conn .close ()
150+ self .connected = False
151+ raise Exception ("Server did not accept protocol. Aborting Connection." )
152+
153+ if self .use_ssl :
154+
155+ # Server is ok with SSL, we switch over SSL.
156+ self .context = ssl .SSLContext (ssl .PROTOCOL_TLS_CLIENT )
157+ self .context .check_hostname = False
158+ # TODO, we need to add support for local certificates
159+ # For now, we let the server send us the certificate
160+ self .context .verify_mode = ssl .VerifyMode .CERT_NONE
161+ self .conn = self .context .wrap_socket (self .conn )
162+
163+ self .connected = True
164+
22165 def __del__ (self ):
23166
24- self .connector .disconnect ()
167+ self .conn .close ()
168+ self .connected = False
25169
26170 def create_new_connection (self ):
27171
28- return Connector (self .ip , self .port )
172+ return Connector (self .host , self .port , session = self .session )
173+
174+ def _query (self , query , blob_array = []):
175+
176+ # Check the query type
177+ if not isinstance (query , str ): # assumes json
178+ query_str = json .dumps (query )
179+ else :
180+ query_str = query
29181
182+ if not self .connected :
183+ return "NOT CONNECTED"
184+
185+ query_msg = queryMessage_pb2 .queryMessage ()
186+ # query has .json and .blobs
187+ query_msg .json = query_str
188+
189+ # Set Auth token, only when not authenticated before
190+ query_msg .token = self .session_token
191+
192+ for blob in blob_array :
193+ query_msg .blobs .append (blob )
194+
195+ # Serialize with protobuf and send
196+ data = query_msg .SerializeToString ();
197+ self ._send_msg (data )
198+
199+ response = self ._recv_msg ()
200+
201+ querRes = queryMessage_pb2 .queryMessage ()
202+ querRes .ParseFromString (response )
203+
204+ response_blob_array = []
205+ for b in querRes .blobs :
206+ response_blob_array .append (b )
207+
208+ self .last_response = json .loads (querRes .json )
209+
210+ return (self .last_response , response_blob_array )
211+
212+ # This is the API method, that has a retry mechanism
30213 def query (self , q , blobs = [], n_retries = 0 ):
31214
32215 if n_retries == 0 :
33216 start = time .time ()
34- self .response , self .blobs = self .connector . query (q , blobs )
217+ self .response , self .blobs = self ._query (q , blobs )
35218 self .last_query_time = time .time () - start
36219 return self .response , self .blobs
37220
@@ -49,15 +232,15 @@ def query(self, q, blobs=[], n_retries=0):
49232 error_msg += " retries\n "
50233 sys .stderr .write (error_msg )
51234 sys .stderr .write ("Response: \n " )
52- sys .stderr .write (self .connector . get_last_response_str ())
235+ sys .stderr .write (self .get_last_response_str ())
53236 sys .stderr .write ("\n " )
54237 sys .stderr .write ("Query: \n " )
55238 sys .stderr .write (str (q ))
56239 sys .stderr .write ("\n " )
57240 break
58241
59242 start = time .time ()
60- self .response , self .blobs = self .connector . query (q , blobs )
243+ self .response , self .blobs = self ._query (q , blobs )
61244 self .last_query_time = time .time () - start
62245 status = self .check_status (self .response )
63246
@@ -67,7 +250,11 @@ def query(self, q, blobs=[], n_retries=0):
67250
68251 def get_last_response_str (self ):
69252
70- return self .connector .get_last_response_str ()
253+ return json .dumps (self .last_response , indent = 4 , sort_keys = False )
254+
255+ def print_last_response (self ):
256+
257+ print (self .get_last_response_str ())
71258
72259 def get_last_query_time (self ):
73260
0 commit comments