Skip to content

Commit 2015f12

Browse files
committed
Changed fake server to threading to make it compatible with paramiko's ServerInterface. Changed tests to use new fake server
1 parent 038059e commit 2015f12

File tree

2 files changed

+78
-57
lines changed

2 files changed

+78
-57
lines changed

fake_server/fake_server.py

Lines changed: 65 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,22 @@
11
#!/usr/bin/env python
22

3-
"""Fake SSH server to test our SSH clients.
4-
Supports execution of commands via exec_command. Does _not_ support interactive shells,
5-
our clients do not use them.
6-
Server private key is hardcoded, server listen code inspired by demo_server.py in paramiko repository"""
3+
"""
4+
Fake SSH server to test our SSH clients.
5+
Supports execution of commands via exec_command. Does _not_ support interactive \
6+
shells, our clients do not use them.
7+
Server private key is hardcoded, server listen code inspired by demo_server.py in \
8+
paramiko repository
9+
"""
710

811
import os
9-
from gevent import socket
10-
import gevent.event
12+
import socket
13+
import threading
14+
from threading import Event
1115
import sys
1216
import traceback
1317
import logging
1418
import paramiko
19+
import time
1520

1621
logger = logging.getLogger(__name__)
1722
paramiko_logger = logging.getLogger('paramiko.transport')
@@ -20,7 +25,7 @@
2025

2126
class Server (paramiko.ServerInterface):
2227
def __init__(self, cmd_req_response = {}, fail_auth = False):
23-
self.event = gevent.event.Event()
28+
self.event = Event()
2429
self.cmd_req_response = cmd_req_response
2530
self.fail_auth = fail_auth
2631

@@ -56,60 +61,77 @@ def check_channel_exec_request(self, channel, cmd):
5661
self.event.set()
5762
return True
5863

59-
def _make_socket(listen_ip, listen_port):
64+
def make_socket(listen_ip):
65+
"""Make socket on given address and available port chosen by OS"""
6066
try:
6167
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
6268
sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
63-
sock.bind((listen_ip, listen_port))
69+
sock.bind((listen_ip, 0))
6470
except Exception, e:
6571
logger.error('Failed to bind to address - %s' % (str(e),))
6672
traceback.print_exc()
6773
return
6874
return sock
6975

70-
def listen(cmd_req_response, listen_ip = '127.0.0.1', listen_port = 2200, fail_auth = False):
71-
"""Run a fake ssh server and given a cmd_to_run, send given response"""
72-
sock = _make_socket(listen_ip, listen_port)
76+
def listen(cmd_req_response, sock, fail_auth = False):
77+
"""Run a fake ssh server and given a cmd_to_run, send given \
78+
response to client connection. Returns (server, socket) tuple \
79+
where server is a joinable server thread and socket is listening \
80+
socket of server."""
81+
# sock = _make_socket(listen_ip)
82+
listen_ip, listen_port = sock.getsockname()
7383
if not sock:
7484
logger.error("Could not establish listening connection on %s:%s", listen_ip, listen_port)
7585
return
7686
try:
7787
sock.listen(100)
7888
logger.info('Listening for connection on %s:%s..', listen_ip, listen_port)
79-
client, addr = sock.accept()
8089
except Exception, e:
81-
logger.error('*** Listen/accept failed: %s' % (str(e),))
90+
logger.error('*** Listen failed: %s' % (str(e),))
8291
traceback.print_exc()
8392
return
93+
accept_thread = threading.Thread(target=handle_ssh_connection,
94+
args=(cmd_req_response, sock,),
95+
kwargs={'fail_auth' : fail_auth},)
96+
accept_thread.start()
97+
return accept_thread
98+
99+
def _handle_ssh_connection(cmd_req_response, t, client, addr, fail_auth = False):
100+
try:
101+
t.load_server_moduli()
102+
except:
103+
return
104+
t.add_server_key(host_key)
105+
server = Server(cmd_req_response = cmd_req_response, fail_auth = fail_auth)
106+
try:
107+
t.start_server(server=server)
108+
except paramiko.SSHException, _:
109+
logger.error('SSH negotiation failed.')
110+
return
111+
return _accept_ssh_data(t, server)
112+
113+
def _accept_ssh_data(t, server):
114+
chan = t.accept(20)
115+
if not chan:
116+
logger.error("Could not establish channel")
117+
return
118+
logger.info("Authenticated..")
119+
chan.send_ready()
120+
server.event.wait(10)
121+
if not server.event.isSet():
122+
logger.error('Client never sent command')
123+
chan.close()
124+
return
125+
while not chan.send_ready():
126+
time.sleep(.5)
127+
chan.close()
128+
129+
def handle_ssh_connection(cmd_req_response, sock, fail_auth = False):
130+
client, addr = sock.accept()
84131
logger.info('Got connection..')
85132
try:
86133
t = paramiko.Transport(client)
87-
try:
88-
t.load_server_moduli()
89-
except:
90-
raise
91-
t.add_server_key(host_key)
92-
server = Server(cmd_req_response = cmd_req_response, fail_auth = fail_auth)
93-
try:
94-
t.start_server(server=server)
95-
except paramiko.SSHException, _:
96-
logger.error('SSH negotiation failed.')
97-
return
98-
chan = t.accept(20)
99-
if not chan:
100-
logger.error("Could not establish channel")
101-
return
102-
logger.info("Authenticated..")
103-
chan.send_ready()
104-
server.event.wait(10)
105-
if not server.event.isSet():
106-
logger.error('Client never sent command')
107-
chan.close()
108-
return
109-
while not chan.send_ready():
110-
gevent.sleep(.5)
111-
chan.close()
112-
134+
_handle_ssh_connection(cmd_req_response, t, client, addr, fail_auth=fail_auth)
113135
except Exception, e:
114136
logger.error('*** Caught exception: %s: %s' % (str(e.__class__), str(e),))
115137
traceback.print_exc()
@@ -122,4 +144,6 @@ def listen(cmd_req_response, listen_ip = '127.0.0.1', listen_port = 2200, fail_a
122144
if __name__ == "__main__":
123145
logging.basicConfig()
124146
logger.setLevel(logging.DEBUG)
125-
listen({'fake' : 'fake response' + os.linesep})
147+
sock = make_socket('127.0.0.1')
148+
server = listen({'fake' : 'fake response' + os.linesep}, sock)
149+
server.join()

tests/test_pssh_client.py

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33
"""Unittests for parallel-ssh"""
44

55
import unittest
6-
import gevent
76
from pssh import ParallelSSHClient, UnknownHostException, \
87
AuthenticationException, ConnectionErrorException, _setup_logger
9-
from fake_server.fake_server import listen, logger as server_logger
8+
from fake_server.fake_server import listen, make_socket, logger as server_logger
109
import random
1110
import logging
1211

@@ -17,35 +16,33 @@ class ParallelSSHClientTest(unittest.TestCase):
1716
def setUp(self):
1817
self.fake_cmd = 'fake cmd'
1918
self.fake_resp = 'fake response'
19+
self.listener = make_socket('127.0.0.1')
20+
self.listen_port = self.listener.getsockname()[1]
21+
22+
def cleanUp(self):
23+
del self.listener
2024

2125
def test_pssh_client_exec_command(self):
22-
listen_port = random.randint(1026, 65534)
23-
server = gevent.spawn(listen, { self.fake_cmd : self.fake_resp }, listen_port = listen_port)
24-
client = ParallelSSHClient(['localhost'], port=listen_port)
25-
gevent.sleep(0)
26+
server = listen({ self.fake_cmd : self.fake_resp }, self.listener)
27+
client = ParallelSSHClient(['localhost'], port=self.listen_port)
2628
cmd = client.exec_command(self.fake_cmd)[0]
2729
output = client.get_stdout(cmd)
2830
expected = {'localhost' : {'exit_code' : 0}}
2931
self.assertEqual(expected, output,
3032
msg = "Got unexpected command output - %s" % (output,))
31-
server.kill()
3233
del client
34+
server.join()
3335

3436
def test_pssh_client_auth_failure(self):
35-
listen_port = random.randint(2048, 65534)
36-
server = gevent.spawn(listen, { self.fake_cmd : self.fake_resp },
37-
listen_port=listen_port, fail_auth=True, )
38-
client = ParallelSSHClient(['localhost'], port=listen_port)
39-
gevent.sleep(0)
40-
server.join(1)
37+
server = listen({ self.fake_cmd : self.fake_resp },
38+
self.listener, fail_auth=True)
39+
client = ParallelSSHClient(['localhost'], port=self.listen_port)
4140
cmd = client.exec_command(self.fake_cmd)[0]
42-
server.join(1)
4341
# Handle exception
4442
try:
4543
cmd.get()
4644
raise Exception("Expected AuthenticationException, got none")
4745
except AuthenticationException:
4846
pass
49-
server.kill()
5047
del client
51-
48+
server.join()

0 commit comments

Comments
 (0)