55import ipaddress
66import random
77import string
8- import subprocess
98from dataclasses import dataclass , field
9+ from io import BytesIO
1010from pathlib import Path
1111
12- from tenacity import retry , retry_if_exception_type , stop_after_attempt , wait_fixed
12+ import netns
13+ from fabric import Connection
14+ from tenacity import retry , stop_after_attempt , wait_fixed
1315
1416from framework import utils
17+ from framework .utils import CommandReturn
1518
1619
1720class SSHConnection :
@@ -22,13 +25,14 @@ class SSHConnection:
2225 the hostname obtained from the MAC address, the username for logging into
2326 the image and the path of the ssh key.
2427
25- This translates into an SSH connection as follows:
26- ssh -i ssh_key_path username@hostname
28+ Uses the fabric library to establish a single connection once, and then
29+ keep it alive for the lifetime of the microvm, to avoid spurious failures
30+ due to reestablishing SSH connections for every single command sent.
2731 """
2832
29- def __init__ (self , netns , ssh_key : Path , host , user , * , on_error = None ):
33+ def __init__ (self , netns_ , ssh_key : Path , host , user , * , on_error = None ):
3034 """Instantiate a SSH client and connect to a microVM."""
31- self .netns = netns
35+ self .netns = netns_
3236 self .ssh_key = ssh_key
3337 # check that the key exists and the permissions are 0o400
3438 # This saves a lot of debugging time.
@@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
4044
4145 self ._on_error = None
4246
43- self .options = [
44- "-o" ,
45- "LogLevel=ERROR" ,
46- "-o" ,
47- "ConnectTimeout=1" ,
48- "-o" ,
49- "StrictHostKeyChecking=no" ,
50- "-o" ,
51- "UserKnownHostsFile=/dev/null" ,
52- "-o" ,
53- "PreferredAuthentications=publickey" ,
54- "-i" ,
55- str (self .ssh_key ),
56- ]
47+ self ._connection = Connection (
48+ host ,
49+ user ,
50+ connect_timeout = 1 ,
51+ connect_kwargs = {
52+ "key_filename" : str (self .ssh_key ),
53+ "banner_timeout" : 1 ,
54+ "auth_timeout" : 1 ,
55+ },
56+ )
5757
5858 # _init_connection loops until it can connect to the guest
5959 # dumping debug state on every iteration is not useful or wanted, so
6060 # only dump it once if _all_ iterations fail.
6161 try :
62- self ._init_connection ()
62+ with netns .NetNS (netns_ ):
63+ self ._init_connection ()
6364 except Exception as exc :
6465 if on_error :
6566 on_error (exc )
@@ -68,35 +69,15 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
6869
6970 self ._on_error = on_error
7071
71- @property
72- def user_host (self ):
73- """remote address for in SSH format <user>@<IP>"""
74- return f"{ self .user } @{ self .host } "
75-
76- def remote_path (self , path ):
77- """Convert a path to remote"""
78- return f"{ self .user_host } :{ path } "
79-
80- def _scp (self , path1 , path2 , options ):
81- """Copy files to/from the VM using scp."""
82- self ._exec (["scp" , * options , path1 , path2 ], check = True )
83-
def scp_put(self, local_path, remote_path):
    """Copy a file from the host into the VM over the open connection.

    Args:
        local_path: Path of the file on the host (``str`` or ``Path``).
        remote_path: Destination path inside the guest (``str`` or ``Path``).
    """
    # Hand plain strings to the connection, the same way the key path is
    # passed as ``str(self.ssh_key)`` when the connection is created.
    # Anything that is not a ``Path`` is forwarded untouched.
    if isinstance(local_path, Path):
        local_path = str(local_path)
    if isinstance(remote_path, Path):
        remote_path = str(remote_path)
    self._connection.put(local_path, remote_path)
9075
def scp_get(self, remote_path, local_path):
    """Copy a file from the VM to the host over the open connection.

    Args:
        remote_path: Path of the file inside the guest (``str`` or ``Path``).
        local_path: Destination path on the host (``str`` or ``Path``).
    """
    # Hand plain strings to the connection, the same way the key path is
    # passed as ``str(self.ssh_key)`` when the connection is created.
    # NOTE(review): fabric's ``get`` appears to call ``str.format`` on the
    # local path, which a ``Path`` does not support - confirm against the
    # installed fabric version.
    if isinstance(remote_path, Path):
        remote_path = str(remote_path)
    if isinstance(local_path, Path):
        local_path = str(local_path)
    self._connection.get(remote_path, local_path)
9779
9880 @retry (
99- retry = retry_if_exception_type (ChildProcessError ),
10081 wait = wait_fixed (0.5 ),
10182 stop = stop_after_attempt (20 ),
10283 reraise = True ,
@@ -106,61 +87,43 @@ def _init_connection(self):
10687
10788 Since we're connecting to a microVM we just started, we'll probably
10889 have to wait for it to boot up and start the SSH server.
109- We'll keep trying to execute a remote command that can't fail
110- (`/bin/true`), until we get a successful (0) exit code .
90+ We'll keep trying to open the connection in a loop for 20 attempts with 0.5s
91+ delay. Each connection attempt has a timeout of 1s .
11192 """
112- self .check_output ( "true" , timeout = 100 , debug = True )
93+ self ._connection . open ( )
11394
def run(self, cmd_string, timeout=None, *, check=False):
    """Execute the command passed as a string in the ssh context.

    Args:
        cmd_string: Command line to execute in the guest.
        timeout: Forwarded unchanged to the connection's ``run``.
        check: When true, a non-zero exit code of the remote command is
            raised as an exception instead of being returned.

    Returns:
        A ``CommandReturn`` built from the remote exit code, stdout and
        stderr.

    Raises:
        Any exception raised by the connection. If an ``on_error``
        callback is installed it is called with the exception first, and
        the exception is then re-raised unchanged.
    """
    try:
        # - warn=True means "do not raise an exception on a non-zero exit
        #   code, just log instead", i.e. it is the inverse of our `check`
        #   argument.
        # - hide=True means "do not always log stdout/stderr".
        # - in_stream=BytesIO(b"") immediately closes stdin of the remote
        #   command. Without it, commands that only exit once their stdin
        #   is closed would hang forever, and that hang would bypass the
        #   pytest timeout.
        result = self._connection.run(
            cmd_string,
            timeout=timeout,
            warn=not check,
            hide=True,
            in_stream=BytesIO(b""),
        )
    except Exception as exc:
        if self._on_error:
            self._on_error(exc)

        raise
    return CommandReturn(result.exited, result.stdout, result.stderr)
143119
144- # pylint:disable=invalid-name
145- def Popen (
146- self ,
147- cmd : str ,
148- stdin = subprocess .DEVNULL ,
149- stdout = subprocess .PIPE ,
150- stderr = subprocess .PIPE ,
151- ** kwargs ,
152- ) -> subprocess .Popen :
153- """Execute the command in the guest and return a Popen object.
154-
155- pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done")
156- pop.stdout.read(16)
157- """
158- cmd = ["ssh" , * self .options , self .user_host , cmd ]
159- if self .netns is not None :
160- cmd = ["ip" , "netns" , "exec" , self .netns ] + cmd
161- return subprocess .Popen (
162- cmd , stdin = stdin , stdout = stdout , stderr = stderr , ** kwargs
163- )
def check_output(self, cmd_string, timeout=None):
    """Same as `run`, but raises an exception on a non-zero return code.

    Args:
        cmd_string: Command line to execute in the guest.
        timeout: Forwarded unchanged to `run`.

    Returns:
        Whatever `run` returns for the command.
    """
    return self.run(cmd_string, timeout, check=True)
123+
def close(self):
    """Close the underlying connection held by this SSHConnection."""
    self._connection.close()
164127
165128
166129def mac_from_ip (ip_address ):
0 commit comments