Skip to content

Commit 354b67e

Browse files
committed
test: use single SSH connection for lifetime of microvm
Instead of creating new SSH connections every time we want to run a command inside the microvm, open a single ssh connection in the constructor of `SSHConnection`, and reuse it until we kill the microvm. Use the `fabric` SSH library to achieve this. Signed-off-by: Patrick Roy <[email protected]>
1 parent 9a8c5a8 commit 354b67e

File tree

4 files changed

+64
-87
lines changed

4 files changed

+64
-87
lines changed

tests/framework/microvm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ def __init__(
247247
self.mem_size_bytes = None
248248
self.cpu_template_name = None
249249

250+
self._ssh_connections = []
251+
250252
self._pre_cmd = []
251253
if numa_node:
252254
node_str = str(numa_node)
@@ -282,6 +284,10 @@ def kill(self):
282284
for monitor in self.monitors:
283285
monitor.stop()
284286

287+
# Cleanup all SSH connections
288+
for conn in self._ssh_connections:
289+
conn.close()
290+
285291
# We start with vhost-user backends,
286292
# because if we stop Firecracker first, the backend will want
287293
# to exit as well and this will cause a race condition.
@@ -1007,13 +1013,15 @@ def ssh_iface(self, iface_idx=0):
10071013
"""Return a cached SSH connection on a given interface id."""
10081014
guest_ip = list(self.iface.values())[iface_idx]["iface"].guest_ip
10091015
self.ssh_key = Path(self.ssh_key)
1010-
return net_tools.SSHConnection(
1011-
netns=self.netns.id,
1016+
connection = net_tools.SSHConnection(
1017+
netns_=self.netns.id,
10121018
ssh_key=self.ssh_key,
10131019
user="root",
10141020
host=guest_ip,
10151021
on_error=self._dump_debug_information,
10161022
)
1023+
self._ssh_connections.append(connection)
1024+
return connection
10171025

10181026
@property
10191027
def ssh(self):

tests/host_tools/network.py

Lines changed: 47 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import ipaddress
66
import random
77
import string
8-
import subprocess
98
from dataclasses import dataclass, field
9+
from io import BytesIO
1010
from pathlib import Path
1111

12-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
12+
import netns
13+
from fabric import Connection
14+
from tenacity import retry, stop_after_attempt, wait_fixed
1315

1416
from framework import utils
17+
from framework.utils import CommandReturn
1518

1619

1720
class SSHConnection:
@@ -22,13 +25,14 @@ class SSHConnection:
2225
the hostname obtained from the MAC address, the username for logging into
2326
the image and the path of the ssh key.
2427
25-
This translates into an SSH connection as follows:
26-
ssh -i ssh_key_path username@hostname
28+
Uses the fabric library to establish a single connection once, and then
29+
keep it alive for the lifetime of the microvm, to avoid spurious failures
30+
due to reestablishing SSH connections for every single command sent.
2731
"""
2832

29-
def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
33+
def __init__(self, netns_, ssh_key: Path, host, user, *, on_error=None):
3034
"""Instantiate a SSH client and connect to a microVM."""
31-
self.netns = netns
35+
self.netns = netns_
3236
self.ssh_key = ssh_key
3337
# check that the key exists and the permissions are 0o400
3438
# This saves a lot of debugging time.
@@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
4044

4145
self._on_error = None
4246

43-
self.options = [
44-
"-o",
45-
"LogLevel=ERROR",
46-
"-o",
47-
"ConnectTimeout=1",
48-
"-o",
49-
"StrictHostKeyChecking=no",
50-
"-o",
51-
"UserKnownHostsFile=/dev/null",
52-
"-o",
53-
"PreferredAuthentications=publickey",
54-
"-i",
55-
str(self.ssh_key),
56-
]
47+
self._connection = Connection(
48+
host,
49+
user,
50+
connect_timeout=1,
51+
connect_kwargs={
52+
"key_filename": str(self.ssh_key),
53+
"banner_timeout": 1,
54+
"auth_timeout": 1,
55+
},
56+
)
5757

5858
# _init_connection loops until it can connect to the guest
5959
# dumping debug state on every iteration is not useful or wanted, so
6060
# only dump it once if _all_ iterations fail.
6161
try:
62-
self._init_connection()
62+
with netns.NetNS(netns_):
63+
self._init_connection()
6364
except Exception as exc:
6465
if on_error:
6566
on_error(exc)
@@ -68,35 +69,19 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
6869

6970
self._on_error = on_error
7071

71-
@property
72-
def user_host(self):
73-
"""remote address for in SSH format <user>@<IP>"""
74-
return f"{self.user}@{self.host}"
75-
76-
def remote_path(self, path):
77-
"""Convert a path to remote"""
78-
return f"{self.user_host}:{path}"
79-
8072
def _scp(self, path1, path2, options):
8173
"""Copy files to/from the VM using scp."""
8274
self._exec(["scp", *options, path1, path2], check=True)
8375

84-
def scp_put(self, local_path, remote_path, recursive=False):
76+
def scp_put(self, local_path, remote_path):
8577
"""Copy files to the VM using scp."""
86-
opts = self.options.copy()
87-
if recursive:
88-
opts.append("-r")
89-
self._scp(local_path, self.remote_path(remote_path), opts)
78+
self._connection.put(local_path, remote_path)
9079

91-
def scp_get(self, remote_path, local_path, recursive=False):
80+
def scp_get(self, remote_path, local_path):
9281
"""Copy files from the VM using scp."""
93-
opts = self.options.copy()
94-
if recursive:
95-
opts.append("-r")
96-
self._scp(self.remote_path(remote_path), local_path, opts)
82+
self._connection.get(remote_path, local_path)
9783

9884
@retry(
99-
retry=retry_if_exception_type(ChildProcessError),
10085
wait=wait_fixed(0.5),
10186
stop=stop_after_attempt(20),
10287
reraise=True,
@@ -106,61 +91,43 @@ def _init_connection(self):
10691
10792
Since we're connecting to a microVM we just started, we'll probably
10893
have to wait for it to boot up and start the SSH server.
109-
We'll keep trying to execute a remote command that can't fail
110-
(`/bin/true`), until we get a successful (0) exit code.
94+
We'll keep trying to open the connection in a loop for 20 attempts with 0.5s
95+
delay. Each connection attempt has a timeout of 1s.
11196
"""
112-
self.check_output("true", timeout=100, debug=True)
97+
self._connection.open()
11398

114-
def run(self, cmd_string, timeout=None, *, check=False, debug=False):
99+
def run(self, cmd_string, timeout=None, *, check=False):
115100
"""
116101
Execute the command passed as a string in the ssh context.
117-
118-
If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr.
119102
"""
120-
command = ["ssh", *self.options, self.user_host, cmd_string]
121-
122-
if debug:
123-
command.insert(1, "-vvv")
124-
125-
return self._exec(command, timeout, check=check)
103+
return self._exec(cmd_string, timeout, check=check)
126104

127-
def check_output(self, cmd_string, timeout=None, *, debug=False):
105+
def check_output(self, cmd_string, timeout=None):
128106
"""Same as `run`, but raises an exception on non-zero return code of remote command"""
129-
return self.run(cmd_string, timeout, check=True, debug=debug)
107+
return self.run(cmd_string, timeout, check=True)
108+
109+
def close(self):
110+
"""Closes this SSHConnection"""
111+
self._connection.close()
130112

131113
def _exec(self, cmd, timeout=None, check=False):
132114
"""Private function that handles the ssh client invocation."""
133-
if self.netns is not None:
134-
cmd = ["ip", "netns", "exec", self.netns] + cmd
135-
136115
try:
137-
return utils.run_cmd(cmd, check=check, timeout=timeout)
116+
# - warn=True means "do not raise exception on non-zero exit code, instead just log", i.e.
117+
# it's the inverse of our "check" argument.
118+
# - hide=True means "do not always log stdout/stderr"
119+
# - in_stream=BytesIO(b"") is needed to immediately close stdin of the remote command
120+
# without this, commands that only exit after their stdin is closed would hang forever
121+
# and this hang would bypass the pytest timeout.
122+
result = self._connection.run(
123+
cmd, timeout=timeout, warn=not check, hide=True, in_stream=BytesIO(b"")
124+
)
138125
except Exception as exc:
139126
if self._on_error:
140127
self._on_error(exc)
141128

142129
raise
143-
144-
# pylint:disable=invalid-name
145-
def Popen(
146-
self,
147-
cmd: str,
148-
stdin=subprocess.DEVNULL,
149-
stdout=subprocess.PIPE,
150-
stderr=subprocess.PIPE,
151-
**kwargs,
152-
) -> subprocess.Popen:
153-
"""Execute the command in the guest and return a Popen object.
154-
155-
pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done")
156-
pop.stdout.read(16)
157-
"""
158-
cmd = ["ssh", *self.options, self.user_host, cmd]
159-
if self.netns is not None:
160-
cmd = ["ip", "netns", "exec", self.netns] + cmd
161-
return subprocess.Popen(
162-
cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwargs
163-
)
130+
return CommandReturn(result.exited, result.stdout, result.stderr)
164131

165132

166133
def mac_from_ip(ip_address):

tests/integration_tests/functional/test_balloon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import logging
66
import time
7-
from subprocess import TimeoutExpired
87

98
import pytest
9+
from invoke import CommandTimedOut
1010
from tenacity import retry, stop_after_attempt, wait_fixed
1111

1212
from framework.utils import check_output, get_free_mem_ssh
@@ -74,7 +74,7 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32):
7474
logger.error("while running: %s", cmd)
7575
logger.error("stdout: %s", stdout)
7676
logger.error("stderr: %s", stderr)
77-
except TimeoutExpired:
77+
except CommandTimedOut:
7878
# It's ok if this expires. Sometimes the SSH connection
7979
# gets killed by the OOM killer *after* the fillmem program
8080
# started. As a result, we can ignore timeouts here.

tests/integration_tests/functional/test_pause_resume.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,18 @@ def test_pause_resume(uvm_nano):
5151
# Flush and reset metrics as they contain pre-pause data.
5252
microvm.flush_metrics()
5353

54-
# Verify guest is no longer active.
55-
with pytest.raises(ChildProcessError):
54+
# Verify guest is no longer active (by observing a failure to reconnect)
55+
with pytest.raises(TimeoutError):
56+
microvm.ssh.close()
5657
microvm.ssh.check_output("true")
5758

5859
# Verify emulation was indeed paused and no events from either
5960
# guest or host side were handled.
6061
verify_net_emulation_paused(microvm.flush_metrics())
6162

6263
# Verify guest is no longer active.
63-
with pytest.raises(ChildProcessError):
64+
with pytest.raises(TimeoutError):
65+
microvm.ssh.close()
6466
microvm.ssh.check_output("true")
6567

6668
# Pausing the microVM when it is already `Paused` is allowed

0 commit comments

Comments
 (0)