55import ipaddress
66import random
77import string
8- import subprocess
98from dataclasses import dataclass , field
9+ from io import BytesIO
1010from pathlib import Path
1111
12- from tenacity import retry , retry_if_exception_type , stop_after_attempt , wait_fixed
12+ import netns
13+ from fabric import Connection
14+ from tenacity import retry , stop_after_attempt , wait_fixed
1315
1416from framework import utils
17+ from framework .utils import CommandReturn
1518
1619
1720class SSHConnection :
@@ -22,13 +25,14 @@ class SSHConnection:
2225 the hostname obtained from the MAC address, the username for logging into
2326 the image and the path of the ssh key.
2427
25- This translates into an SSH connection as follows:
26- ssh -i ssh_key_path username@hostname
28+ Uses the fabric library to establish a single connection once, and then
29+ keep it alive for the lifetime of the microvm, to avoid spurious failures
30+ due to reestablishing SSH connections for every single command sent.
2731 """
2832
29- def __init__ (self , netns , ssh_key : Path , host , user , * , on_error = None ):
33+ def __init__ (self , netns_ , ssh_key : Path , host , user , * , on_error = None ):
3034 """Instantiate a SSH client and connect to a microVM."""
31- self .netns = netns
35+ self .netns = netns_
3236 self .ssh_key = ssh_key
3337 # check that the key exists and the permissions are 0o400
3438 # This saves a lot of debugging time.
@@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
4044
4145 self ._on_error = None
4246
43- self .options = [
44- "-o" ,
45- "LogLevel=ERROR" ,
46- "-o" ,
47- "ConnectTimeout=1" ,
48- "-o" ,
49- "StrictHostKeyChecking=no" ,
50- "-o" ,
51- "UserKnownHostsFile=/dev/null" ,
52- "-o" ,
53- "PreferredAuthentications=publickey" ,
54- "-i" ,
55- str (self .ssh_key ),
56- ]
47+ self ._connection = Connection (
48+ host ,
49+ user ,
50+ connect_timeout = 1 ,
51+ connect_kwargs = {
52+ "key_filename" : str (self .ssh_key ),
53+ "banner_timeout" : 1 ,
54+ "auth_timeout" : 1 ,
55+ },
56+ )
5757
5858 # _init_connection loops until it can connect to the guest
5959 # dumping debug state on every iteration is not useful or wanted, so
6060 # only dump it once if _all_ iterations fail.
6161 try :
62- self ._init_connection ()
62+ with netns .NetNS (netns_ ):
63+ self ._init_connection ()
6364 except Exception as exc :
6465 if on_error :
6566 on_error (exc )
@@ -68,35 +69,19 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
6869
6970 self ._on_error = on_error
7071
71- @property
72- def user_host (self ):
73- """remote address for in SSH format <user>@<IP>"""
74- return f"{ self .user } @{ self .host } "
75-
76- def remote_path (self , path ):
77- """Convert a path to remote"""
78- return f"{ self .user_host } :{ path } "
79-
8072 def _scp (self , path1 , path2 , options ):
8173 """Copy files to/from the VM using scp."""
8274 self ._exec (["scp" , * options , path1 , path2 ], check = True )
8375
84- def scp_put (self , local_path , remote_path , recursive = False ):
76+ def scp_put (self , local_path , remote_path ):
8577 """Copy files to the VM using scp."""
86- opts = self .options .copy ()
87- if recursive :
88- opts .append ("-r" )
89- self ._scp (local_path , self .remote_path (remote_path ), opts )
78+ self ._connection .put (local_path , remote_path )
9079
91- def scp_get (self , remote_path , local_path , recursive = False ):
80+ def scp_get (self , remote_path , local_path ):
9281 """Copy files from the VM using scp."""
93- opts = self .options .copy ()
94- if recursive :
95- opts .append ("-r" )
96- self ._scp (self .remote_path (remote_path ), local_path , opts )
82+ self ._connection .get (remote_path , local_path )
9783
9884 @retry (
99- retry = retry_if_exception_type (ChildProcessError ),
10085 wait = wait_fixed (0.5 ),
10186 stop = stop_after_attempt (20 ),
10287 reraise = True ,
@@ -106,61 +91,43 @@ def _init_connection(self):
10691
10792 Since we're connecting to a microVM we just started, we'll probably
10893 have to wait for it to boot up and start the SSH server.
109- We'll keep trying to execute a remote command that can't fail
110- (`/bin/true`), until we get a successful (0) exit code .
94+ We'll keep trying to open the connection in a loop for 20 attempts with 0.5s
95+ delay. Each connection attempt has a timeout of 1s .
11196 """
112- self .check_output ( "true" , timeout = 100 , debug = True )
97+ self ._connection . open ( )
11398
114- def run (self , cmd_string , timeout = None , * , check = False , debug = False ):
99+ def run (self , cmd_string , timeout = None , * , check = False ):
115100 """
116101 Execute the command passed as a string in the ssh context.
117-
118- If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr.
119102 """
120- command = ["ssh" , * self .options , self .user_host , cmd_string ]
121-
122- if debug :
123- command .insert (1 , "-vvv" )
124-
125- return self ._exec (command , timeout , check = check )
103+ return self ._exec (cmd_string , timeout , check = check )
126104
127- def check_output (self , cmd_string , timeout = None , * , debug = False ):
105+ def check_output (self , cmd_string , timeout = None ):
128106 """Same as `run`, but raises an exception on non-zero return code of remote command"""
129- return self .run (cmd_string , timeout , check = True , debug = debug )
107+ return self .run (cmd_string , timeout , check = True )
108+
109+ def close (self ):
110+ """Closes this SSHConnection"""
111+ self ._connection .close ()
130112
131113 def _exec (self , cmd , timeout = None , check = False ):
132114 """Private function that handles the ssh client invocation."""
133- if self .netns is not None :
134- cmd = ["ip" , "netns" , "exec" , self .netns ] + cmd
135-
136115 try :
137- return utils .run_cmd (cmd , check = check , timeout = timeout )
116+ # - warn=True means "do not raise exception on non-zero exit code, instead just log", e.g.
117+ # it's the inverse of our "check" argument.
118+ # - hide=True means "do not always log stdout/stderr"
119+ # - in_stream=BytesIO(b"") is needed to immediately close stdin of the remote command
120+ # without this, command that only exit after their stdin is closed would hang forever
121+ # and this hang would bypass the pytest timeout.
122+ result = self ._connection .run (
123+ cmd , timeout = timeout , warn = not check , hide = True , in_stream = BytesIO (b"" )
124+ )
138125 except Exception as exc :
139126 if self ._on_error :
140127 self ._on_error (exc )
141128
142129 raise
143-
144- # pylint:disable=invalid-name
145- def Popen (
146- self ,
147- cmd : str ,
148- stdin = subprocess .DEVNULL ,
149- stdout = subprocess .PIPE ,
150- stderr = subprocess .PIPE ,
151- ** kwargs ,
152- ) -> subprocess .Popen :
153- """Execute the command in the guest and return a Popen object.
154-
155- pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done")
156- pop.stdout.read(16)
157- """
158- cmd = ["ssh" , * self .options , self .user_host , cmd ]
159- if self .netns is not None :
160- cmd = ["ip" , "netns" , "exec" , self .netns ] + cmd
161- return subprocess .Popen (
162- cmd , stdin = stdin , stdout = stdout , stderr = stderr , ** kwargs
163- )
130+ return CommandReturn (result .exited , result .stdout , result .stderr )
164131
165132
166133def mac_from_ip (ip_address ):
0 commit comments