1414from framework import utils
1515
1616
17+ class SSHConnectionException (Exception ):
18+ """
19+ Specific exception for ssh errors
20+ """
21+ pass
22+
23+
1724class SSHConnection :
1825 """
1926 SSHConnection encapsulates functionality for microVM SSH interaction.
@@ -58,13 +65,7 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
5865 # _init_connection loops until it can connect to the guest
5966 # dumping debug state on every iteration is not useful or wanted, so
6067 # only dump it once if _all_ iterations fail.
61- try :
62- self ._init_connection ()
63- except Exception as exc :
64- if on_error :
65- on_error (exc )
66-
67- raise
68+ self ._init_connection ()
6869
6970 self ._on_error = on_error
7071
@@ -79,7 +80,7 @@ def remote_path(self, path):
7980
8081 def _scp (self , path1 , path2 , options ):
8182 """Copy files to/from the VM using scp."""
82- self ._exec (["scp" , * options , path1 , path2 ], check = True )
83+ self ._exec (["scp" , * options , path1 , path2 ])
8384
8485 def scp_put (self , local_path , remote_path , recursive = False ):
8586 """Copy files to the VM using scp."""
@@ -96,7 +97,7 @@ def scp_get(self, remote_path, local_path, recursive=False):
9697 self ._scp (self .remote_path (remote_path ), local_path , opts )
9798
9899 @retry (
99- retry = retry_if_exception_type (ChildProcessError ),
100+ retry = retry_if_exception_type (SSHConnectionException ),
100101 wait = wait_fixed (0.5 ),
101102 stop = stop_after_attempt (20 ),
102103 reraise = True ,
@@ -109,11 +110,12 @@ def _init_connection(self):
109110 We'll keep trying to execute a remote command that can't fail
110111 (`/bin/true`), until we get a successful (0) exit code.
111112 """
112- self .check_output ("true" , timeout = 100 , debug = True )
113+ self .run ("true" , timeout = 100 , debug = True )
113114
114- def run (self , cmd_string , timeout = None , * , check = False , debug = False ):
115+ def run (self , cmd_string , timeout = None , * , check = True , debug = False ):
115116 """
116117 Execute the command passed as a string in the ssh context.
118+ By default raises an exception on non-zero return code of remote command.
117119
118120 If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr.
119121 """
@@ -124,11 +126,7 @@ def run(self, cmd_string, timeout=None, *, check=False, debug=False):
124126
125127 return self ._exec (command , timeout , check = check )
126128
127- def check_output (self , cmd_string , timeout = None , * , debug = False ):
128- """Same as `run`, but raises an exception on non-zero return code of remote command"""
129- return self .run (cmd_string , timeout , check = True , debug = debug )
130-
131- def _exec (self , cmd , timeout = None , check = False ):
129+ def _exec (self , cmd , timeout = None , check = True ):
132130 """Private function that handles the ssh client invocation."""
133131 if self .netns is not None :
134132 cmd = ["ip" , "netns" , "exec" , self .netns ] + cmd
0 commit comments