Skip to content

Commit f247ed2

Browse files
committed
forwarder: add support for domain socket; more flexible local server loop
1 parent 4e3fb74 commit f247ed2

File tree

1 file changed

+29
-11
lines changed

1 file changed

+29
-11
lines changed

pyunicore/forwarder.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,26 @@ def __init__(
1717
service_port=None,
1818
service_host=None,
1919
login_node=None,
20+
unix_socketfile=None,
2021
debug=False,
2122
):
2223
"""Creates a new Forwarder instance
23-
The remote service host/port can be already encoded in the endpoint, or given separately
24+
The remote service address (host/port or domain socket filename) can be
25+
already encoded in the endpoint, or given separately
2426
2527
Args:
2628
transport: the transport (security sessions should be OFF)
2729
endpoint: UNICORE REST API endpoint which can establish the forwarding
2830
service_port: the remote service port (if not already encoded in the endpoint)
2931
service_host: the (optional) remote service host (if not encoded in the endpoint)
3032
login_node: the /optional) login node to use (if not encoded in the endpoint)
33+
unix_socketfile: the name of the remote service's UNIX domain socket file
34+
(in the job working directory)
3135
debug: set to True for some debug output to the console
3236
"""
3337
self.endpoint = endpoint
3438
self.parsed_url = _parse_forwarding_params(
35-
self.endpoint, service_port, service_host, login_node
39+
self.endpoint, service_port, service_host, login_node, unix_socketfile
3640
)
3741
self.transport = transport
3842
self.quiet = not debug
@@ -101,10 +105,10 @@ def create_socket(self):
101105
def start_forwarding(self):
102106
self.quiet or print("Start forwarding.")
103107
threading.Thread(
104-
target=self.transfer, args=(self.client_socket, self.service_socket)
108+
target=self.transfer, args=(self.client_socket, self.service_socket, "local->remote")
105109
).start()
106110
threading.Thread(
107-
target=self.transfer, args=(self.service_socket, self.client_socket)
111+
target=self.transfer, args=(self.service_socket, self.client_socket, "remote->local")
108112
).start()
109113

110114
def stop_forwarding(self):
@@ -119,8 +123,8 @@ def stop_forwarding(self):
119123
except OSError:
120124
pass
121125

122-
def transfer(self, source, destination):
123-
desc = f"{source.getpeername()} --> {destination.getpeername()}"
126+
def transfer(self, source, destination, name=""):
127+
desc = f"{name}: {source.getpeername()} --> {destination.getpeername()}"
124128
self.quiet or print("Start TCP forwarding %s" % desc)
125129
buf_size = 32768
126130
while True:
@@ -132,7 +136,7 @@ def transfer(self, source, destination):
132136
self.quiet or print("Source is at EOF for %s" % desc)
133137
break
134138
except OSError as e:
135-
self.quiet or print("I/O ERROR for %s " % desc, e)
139+
self.quiet or print("I/O ERROR for %s" % desc, e)
136140
for s in source, destination:
137141
try:
138142
s.close()
@@ -141,7 +145,7 @@ def transfer(self, source, destination):
141145
break
142146
self.quiet or print("Stopping TCP forwarding %s" % desc)
143147

144-
def run(self, local_port):
148+
def run(self, local_port, keep_alive=False):
145149
"""open a listener, accept client connections and forward them to the backend"""
146150
with socket.socket() as server:
147151
server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
@@ -156,9 +160,13 @@ def run(self, local_port):
156160
self.quiet or print("Client %s connected." % str(self.client_socket.getpeername()))
157161
self.service_socket = self.connect()
158162
self.start_forwarding()
163+
if not keep_alive:
164+
break
159165

160166

161-
def _parse_forwarding_params(endpoint, service_port=None, service_host=None, login_node=None):
167+
def _parse_forwarding_params(
168+
endpoint, service_port=None, service_host=None, login_node=None, unix_socketfile=None
169+
):
162170
"""If not already present in the endpoint, the parameters like
163171
service_port are added.
164172
@@ -167,6 +175,8 @@ def _parse_forwarding_params(endpoint, service_port=None, service_host=None, log
167175
"""
168176
parsed_url = urlparse(endpoint)
169177
q = parsed_url.query
178+
if (unix_socketfile is not None) and (service_port is not None):
179+
raise ValueError("Only one of 'file' and 'service_host'/'service_port' is allowed")
170180
if service_port is not None and "port=" not in endpoint:
171181
if len(q) > 0:
172182
q += "&"
@@ -179,17 +189,25 @@ def _parse_forwarding_params(endpoint, service_port=None, service_host=None, log
179189
if len(q) > 0:
180190
q += "&"
181191
q += "loginNode=%s" % login_node
192+
if unix_socketfile is not None and "file=" not in endpoint:
193+
if len(q) > 0:
194+
q += "&"
195+
q += "file=%s" % login_node
182196
return parsed_url._replace(query=q)
183197

184198

185-
def open_tunnel(job, service_port=None, service_host=None, login_node=None, debug=False):
199+
def open_tunnel(
200+
job, service_port=None, service_host=None, login_node=None, unix_socketfile=None, debug=False
201+
):
186202
"""open a tunnel to a service running on the HPC side
187203
and return the connected socket
188204
"""
189205
endpoint = job.links["forwarding"]
190206
tr = job.transport._clone()
191207
tr.use_security_sessions = False
192-
forwarder = Forwarder(tr, endpoint, service_port, service_host, login_node, debug)
208+
forwarder = Forwarder(
209+
tr, endpoint, service_port, service_host, login_node, unix_socketfile, debug
210+
)
193211
return forwarder.connect()
194212

195213

0 commit comments

Comments
 (0)