@@ -17,22 +17,26 @@ def __init__(
1717 service_port = None ,
1818 service_host = None ,
1919 login_node = None ,
20+ unix_socketfile = None ,
2021 debug = False ,
2122 ):
2223 """Creates a new Forwarder instance
23- The remote service host/port can be already encoded in the endpoint, or given separately
24+ The remote service address (host/port or domain socket filename) can be
25+ already encoded in the endpoint, or given separately
2426
2527 Args:
2628 transport: the transport (security sessions should be OFF)
2729 endpoint: UNICORE REST API endpoint which can establish the forwarding
2830 service_port: the remote service port (if not already encoded in the endpoint)
2931 service_host: the (optional) remote service host (if not encoded in the endpoint)
3032 login_node: the /optional) login node to use (if not encoded in the endpoint)
33+ unix_socketfile: the name of the remote service's UNIX domain socket file
34+ (in the job working directory)
3135 debug: set to True for some debug output to the console
3236 """
3337 self .endpoint = endpoint
3438 self .parsed_url = _parse_forwarding_params (
35- self .endpoint , service_port , service_host , login_node
39+ self .endpoint , service_port , service_host , login_node , unix_socketfile
3640 )
3741 self .transport = transport
3842 self .quiet = not debug
@@ -101,10 +105,10 @@ def create_socket(self):
101105 def start_forwarding (self ):
102106 self .quiet or print ("Start forwarding." )
103107 threading .Thread (
104- target = self .transfer , args = (self .client_socket , self .service_socket )
108+ target = self .transfer , args = (self .client_socket , self .service_socket , "local->remote" )
105109 ).start ()
106110 threading .Thread (
107- target = self .transfer , args = (self .service_socket , self .client_socket )
111+ target = self .transfer , args = (self .service_socket , self .client_socket , "remote->local" )
108112 ).start ()
109113
110114 def stop_forwarding (self ):
@@ -119,8 +123,8 @@ def stop_forwarding(self):
119123 except OSError :
120124 pass
121125
122- def transfer (self , source , destination ):
123- desc = f"{ source .getpeername ()} --> { destination .getpeername ()} "
126+ def transfer (self , source , destination , name = "" ):
127+ desc = f"{ name } : { source .getpeername ()} --> { destination .getpeername ()} "
124128 self .quiet or print ("Start TCP forwarding %s" % desc )
125129 buf_size = 32768
126130 while True :
@@ -132,7 +136,7 @@ def transfer(self, source, destination):
132136 self .quiet or print ("Source is at EOF for %s" % desc )
133137 break
134138 except OSError as e :
135- self .quiet or print ("I/O ERROR for %s " % desc , e )
139+ self .quiet or print ("I/O ERROR for %s" % desc , e )
136140 for s in source , destination :
137141 try :
138142 s .close ()
@@ -141,7 +145,7 @@ def transfer(self, source, destination):
141145 break
142146 self .quiet or print ("Stopping TCP forwarding %s" % desc )
143147
144- def run (self , local_port ):
148+ def run (self , local_port , keep_alive = False ):
145149 """open a listener, accept client connections and forward them to the backend"""
146150 with socket .socket () as server :
147151 server .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
@@ -156,9 +160,13 @@ def run(self, local_port):
156160 self .quiet or print ("Client %s connected." % str (self .client_socket .getpeername ()))
157161 self .service_socket = self .connect ()
158162 self .start_forwarding ()
163+ if not keep_alive :
164+ break
159165
160166
161- def _parse_forwarding_params (endpoint , service_port = None , service_host = None , login_node = None ):
167+ def _parse_forwarding_params (
168+ endpoint , service_port = None , service_host = None , login_node = None , unix_socketfile = None
169+ ):
162170 """If not already present in the endpoint, the parameters like
163171 service_port are added.
164172
@@ -167,6 +175,8 @@ def _parse_forwarding_params(endpoint, service_port=None, service_host=None, log
167175 """
168176 parsed_url = urlparse (endpoint )
169177 q = parsed_url .query
178+ if (unix_socketfile is not None ) and (service_port is not None ):
179+ raise ValueError ("Only one of 'file' and 'service_host'/'service_port' is allowed" )
170180 if service_port is not None and "port=" not in endpoint :
171181 if len (q ) > 0 :
172182 q += "&"
@@ -179,17 +189,25 @@ def _parse_forwarding_params(endpoint, service_port=None, service_host=None, log
179189 if len (q ) > 0 :
180190 q += "&"
181191 q += "loginNode=%s" % login_node
192+ if unix_socketfile is not None and "file=" not in endpoint :
193+ if len (q ) > 0 :
194+ q += "&"
195+ q += "file=%s" % login_node
182196 return parsed_url ._replace (query = q )
183197
184198
185- def open_tunnel (job , service_port = None , service_host = None , login_node = None , debug = False ):
199+ def open_tunnel (
200+ job , service_port = None , service_host = None , login_node = None , unix_socketfile = None , debug = False
201+ ):
186202 """open a tunnel to a service running on the HPC side
187203 and return the connected socket
188204 """
189205 endpoint = job .links ["forwarding" ]
190206 tr = job .transport ._clone ()
191207 tr .use_security_sessions = False
192- forwarder = Forwarder (tr , endpoint , service_port , service_host , login_node , debug )
208+ forwarder = Forwarder (
209+ tr , endpoint , service_port , service_host , login_node , unix_socketfile , debug
210+ )
193211 return forwarder .connect ()
194212
195213
0 commit comments