Skip to content

Commit 85a2f24

Browse files
committed
test: use single SSH connection for lifetime of microvm
Instead of creating new SSH connections every time we want to run a command inside the microvm, open a single ssh connection in the constructor of `SSHConnection`, and reuse it until we kill the microvm. Use the `fabric` SSH library to achieve this. Since fabric by default does not support opening its connections in specific network namespaces (e.g. ip netns exec), explicitly switch into the target network namespace before establishing the connected. Since entering network namspaces happens per-thread (instead of per-process), this is fine to do even in a highly-multithreaded pytest environment. Signed-off-by: Patrick Roy <[email protected]>
1 parent 9a8c5a8 commit 85a2f24

File tree

4 files changed

+68
-95
lines changed

4 files changed

+68
-95
lines changed

tests/framework/microvm.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ def __init__(
247247
self.mem_size_bytes = None
248248
self.cpu_template_name = None
249249

250+
self._ssh_connections = []
251+
250252
self._pre_cmd = []
251253
if numa_node:
252254
node_str = str(numa_node)
@@ -282,6 +284,10 @@ def kill(self):
282284
for monitor in self.monitors:
283285
monitor.stop()
284286

287+
# Cleanup all SSH connections
288+
for conn in self._ssh_connections:
289+
conn.close()
290+
285291
# We start with vhost-user backends,
286292
# because if we stop Firecracker first, the backend will want
287293
# to exit as well and this will cause a race condition.
@@ -1007,13 +1013,15 @@ def ssh_iface(self, iface_idx=0):
10071013
"""Return a cached SSH connection on a given interface id."""
10081014
guest_ip = list(self.iface.values())[iface_idx]["iface"].guest_ip
10091015
self.ssh_key = Path(self.ssh_key)
1010-
return net_tools.SSHConnection(
1011-
netns=self.netns.id,
1016+
connection = net_tools.SSHConnection(
1017+
netns_=self.netns.id,
10121018
ssh_key=self.ssh_key,
10131019
user="root",
10141020
host=guest_ip,
10151021
on_error=self._dump_debug_information,
10161022
)
1023+
self._ssh_connections.append(connection)
1024+
return connection
10171025

10181026
@property
10191027
def ssh(self):

tests/host_tools/network.py

Lines changed: 51 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
import ipaddress
66
import random
77
import string
8-
import subprocess
98
from dataclasses import dataclass, field
9+
from io import BytesIO
1010
from pathlib import Path
1111

12-
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed
12+
import netns
13+
from fabric import Connection
14+
from tenacity import retry, stop_after_attempt, wait_fixed
1315

1416
from framework import utils
17+
from framework.utils import CommandReturn
1518

1619

1720
class SSHConnection:
@@ -22,13 +25,14 @@ class SSHConnection:
2225
the hostname obtained from the MAC address, the username for logging into
2326
the image and the path of the ssh key.
2427
25-
This translates into an SSH connection as follows:
26-
ssh -i ssh_key_path username@hostname
28+
Uses the fabric library to establish a single connection once, and then
29+
keep it alive for the lifetime of the microvm, to avoid spurious failures
30+
due to reestablishing SSH connections for every single command sent.
2731
"""
2832

29-
def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
33+
def __init__(self, netns_, ssh_key: Path, host, user, *, on_error=None):
3034
"""Instantiate a SSH client and connect to a microVM."""
31-
self.netns = netns
35+
self.netns = netns_
3236
self.ssh_key = ssh_key
3337
# check that the key exists and the permissions are 0o400
3438
# This saves a lot of debugging time.
@@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
4044

4145
self._on_error = None
4246

43-
self.options = [
44-
"-o",
45-
"LogLevel=ERROR",
46-
"-o",
47-
"ConnectTimeout=1",
48-
"-o",
49-
"StrictHostKeyChecking=no",
50-
"-o",
51-
"UserKnownHostsFile=/dev/null",
52-
"-o",
53-
"PreferredAuthentications=publickey",
54-
"-i",
55-
str(self.ssh_key),
56-
]
47+
self._connection = Connection(
48+
host,
49+
user,
50+
connect_timeout=1,
51+
connect_kwargs={
52+
"key_filename": str(self.ssh_key),
53+
"banner_timeout": 1,
54+
"auth_timeout": 1,
55+
},
56+
)
5757

5858
# _init_connection loops until it can connect to the guest
5959
# dumping debug state on every iteration is not useful or wanted, so
6060
# only dump it once if _all_ iterations fail.
6161
try:
62-
self._init_connection()
62+
with netns.NetNS(netns_):
63+
self._init_connection()
6364
except Exception as exc:
6465
if on_error:
6566
on_error(exc)
@@ -68,35 +69,15 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None):
6869

6970
self._on_error = on_error
7071

71-
@property
72-
def user_host(self):
73-
"""remote address for in SSH format <user>@<IP>"""
74-
return f"{self.user}@{self.host}"
75-
76-
def remote_path(self, path):
77-
"""Convert a path to remote"""
78-
return f"{self.user_host}:{path}"
79-
80-
def _scp(self, path1, path2, options):
81-
"""Copy files to/from the VM using scp."""
82-
self._exec(["scp", *options, path1, path2], check=True)
83-
84-
def scp_put(self, local_path, remote_path, recursive=False):
72+
def scp_put(self, local_path, remote_path):
8573
"""Copy files to the VM using scp."""
86-
opts = self.options.copy()
87-
if recursive:
88-
opts.append("-r")
89-
self._scp(local_path, self.remote_path(remote_path), opts)
74+
self._connection.put(local_path, remote_path)
9075

91-
def scp_get(self, remote_path, local_path, recursive=False):
76+
def scp_get(self, remote_path, local_path):
9277
"""Copy files from the VM using scp."""
93-
opts = self.options.copy()
94-
if recursive:
95-
opts.append("-r")
96-
self._scp(self.remote_path(remote_path), local_path, opts)
78+
self._connection.get(remote_path, local_path)
9779

9880
@retry(
99-
retry=retry_if_exception_type(ChildProcessError),
10081
wait=wait_fixed(0.5),
10182
stop=stop_after_attempt(20),
10283
reraise=True,
@@ -106,61 +87,43 @@ def _init_connection(self):
10687
10788
Since we're connecting to a microVM we just started, we'll probably
10889
have to wait for it to boot up and start the SSH server.
109-
We'll keep trying to execute a remote command that can't fail
110-
(`/bin/true`), until we get a successful (0) exit code.
90+
We'll keep trying to open the connection in a loop for 20 attempts with 0.5s
91+
delay. Each connection attempt has a timeout of 1s.
11192
"""
112-
self.check_output("true", timeout=100, debug=True)
93+
self._connection.open()
11394

114-
def run(self, cmd_string, timeout=None, *, check=False, debug=False):
95+
def run(self, cmd_string, timeout=None, *, check=False):
11596
"""
11697
Execute the command passed as a string in the ssh context.
117-
118-
If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr.
11998
"""
120-
command = ["ssh", *self.options, self.user_host, cmd_string]
121-
122-
if debug:
123-
command.insert(1, "-vvv")
124-
125-
return self._exec(command, timeout, check=check)
126-
127-
def check_output(self, cmd_string, timeout=None, *, debug=False):
128-
"""Same as `run`, but raises an exception on non-zero return code of remote command"""
129-
return self.run(cmd_string, timeout, check=True, debug=debug)
130-
131-
def _exec(self, cmd, timeout=None, check=False):
132-
"""Private function that handles the ssh client invocation."""
133-
if self.netns is not None:
134-
cmd = ["ip", "netns", "exec", self.netns] + cmd
135-
13699
try:
137-
return utils.run_cmd(cmd, check=check, timeout=timeout)
100+
# - warn=True means "do not raise exception on non-zero exit code, instead just log", e.g.
101+
# it's the inverse of our "check" argument.
102+
# - hide=True means "do not always log stdout/stderr"
103+
# - in_stream=BytesIO(b"") is needed to immediately close stdin of the remote command
104+
# without this, command that only exit after their stdin is closed would hang forever
105+
# and this hang would bypass the pytest timeout.
106+
result = self._connection.run(
107+
cmd_string,
108+
timeout=timeout,
109+
warn=not check,
110+
hide=True,
111+
in_stream=BytesIO(b""),
112+
)
138113
except Exception as exc:
139114
if self._on_error:
140115
self._on_error(exc)
141116

142117
raise
118+
return CommandReturn(result.exited, result.stdout, result.stderr)
143119

144-
# pylint:disable=invalid-name
145-
def Popen(
146-
self,
147-
cmd: str,
148-
stdin=subprocess.DEVNULL,
149-
stdout=subprocess.PIPE,
150-
stderr=subprocess.PIPE,
151-
**kwargs,
152-
) -> subprocess.Popen:
153-
"""Execute the command in the guest and return a Popen object.
154-
155-
pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done")
156-
pop.stdout.read(16)
157-
"""
158-
cmd = ["ssh", *self.options, self.user_host, cmd]
159-
if self.netns is not None:
160-
cmd = ["ip", "netns", "exec", self.netns] + cmd
161-
return subprocess.Popen(
162-
cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwargs
163-
)
120+
def check_output(self, cmd_string, timeout=None):
121+
"""Same as `run`, but raises an exception on non-zero return code of remote command"""
122+
return self.run(cmd_string, timeout, check=True)
123+
124+
def close(self):
125+
"""Closes this SSHConnection"""
126+
self._connection.close()
164127

165128

166129
def mac_from_ip(ip_address):

tests/integration_tests/functional/test_balloon.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44

55
import logging
66
import time
7-
from subprocess import TimeoutExpired
87

98
import pytest
9+
from invoke import CommandTimedOut
1010
from tenacity import retry, stop_after_attempt, wait_fixed
1111

1212
from framework.utils import check_output, get_free_mem_ssh
@@ -74,7 +74,7 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32):
7474
logger.error("while running: %s", cmd)
7575
logger.error("stdout: %s", stdout)
7676
logger.error("stderr: %s", stderr)
77-
except TimeoutExpired:
77+
except CommandTimedOut:
7878
# It's ok if this expires. Sometimes the SSH connection
7979
# gets killed by the OOM killer *after* the fillmem program
8080
# started. As a result, we can ignore timeouts here.

tests/integration_tests/functional/test_pause_resume.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,16 +51,18 @@ def test_pause_resume(uvm_nano):
5151
# Flush and reset metrics as they contain pre-pause data.
5252
microvm.flush_metrics()
5353

54-
# Verify guest is no longer active.
55-
with pytest.raises(ChildProcessError):
54+
# Verify guest is no longer active (by observing a failure to reconnect)
55+
with pytest.raises(TimeoutError):
56+
microvm.ssh.close()
5657
microvm.ssh.check_output("true")
5758

5859
# Verify emulation was indeed paused and no events from either
5960
# guest or host side were handled.
6061
verify_net_emulation_paused(microvm.flush_metrics())
6162

6263
# Verify guest is no longer active.
63-
with pytest.raises(ChildProcessError):
64+
with pytest.raises(TimeoutError):
65+
microvm.ssh.close()
6466
microvm.ssh.check_output("true")
6567

6668
# Pausing the microVM when it is already `Paused` is allowed

0 commit comments

Comments
 (0)