Skip to content

Commit 382e5fd

Browse files
committed
v4.3.18: Read IdentityFile from SSH config + configurable host key policy
SSH tunnel now reads IdentityFile directives from ~/.ssh/config and passes them as key_filename to paramiko.SSHClient.connect(). Host-specific keys are tried first, then wildcard entries. Non-existent files are skipped. look_for_keys=False is maintained to prevent blind scanning of ~/.ssh/. Auth order: key_filename (specific->wildcard) -> agent -> password. New host_key_policy setting in [ssh tunnels] config section: - auto-add (default): accept unknown keys and add to known_hosts (TOFU) - warn: accept unknown keys but log a warning - reject: only connect to hosts already in ~/.ssh/known_hosts Made with ❤️ and 🤖 Claude
1 parent 93c3a1d commit 382e5fd

File tree

4 files changed

+224
-7
lines changed

4 files changed

+224
-7
lines changed

pgcli/main.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ def __init__(
309309
dsn_ssh_tunnel_config=c.get("dsn ssh tunnels"),
310310
logger=self.logger,
311311
allow_agent=str(ssh_tunnels_config.get("allow_agent", "True")).lower() == "true",
312+
host_key_policy=str(ssh_tunnels_config.get("host_key_policy", "auto-add")).lower(),
312313
)
313314
self.ssh_tunnel = None
314315

pgcli/pgclirc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,6 +311,11 @@ float = ""
311311
# Use SSH agent for key authentication (default: True)
312312
# Set to False if you want to use only keys specified in ~/.ssh/config
313313
allow_agent = True
314+
# SSH host key verification policy (default: auto-add)
315+
# auto-add = accept unknown host keys and add to known_hosts (TOFU)
316+
# warn = accept unknown host keys but log a warning
317+
# reject = only connect to hosts already in ~/.ssh/known_hosts
318+
host_key_policy = auto-add
314319

315320
# ^example.*\.host$ = myuser:mypasswd@my.tunnel.com:4000
316321
# .*\.net = another.tunnel.com

pgcli/ssh_tunnel.py

Lines changed: 35 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,12 @@ class _NativeSSHTunnel:
7979
Binds a local TCP port and forwards connections through an SSH channel.
8080
"""
8181

82+
HOST_KEY_POLICIES = {
83+
"auto-add": paramiko.AutoAddPolicy,
84+
"warn": paramiko.WarningPolicy,
85+
"reject": paramiko.RejectPolicy,
86+
}
87+
8288
def __init__(
8389
self,
8490
ssh_hostname: str,
@@ -89,6 +95,8 @@ def __init__(
8995
ssh_password: Optional[str] = None,
9096
ssh_proxy: Optional[Any] = None,
9197
allow_agent: bool = True,
98+
key_filenames: Optional[list] = None,
99+
host_key_policy: str = "auto-add",
92100
logger: Optional[logging.Logger] = None,
93101
):
94102
self.ssh_hostname = ssh_hostname
@@ -99,6 +107,8 @@ def __init__(
99107
self.ssh_password = ssh_password
100108
self.ssh_proxy = ssh_proxy
101109
self.allow_agent = allow_agent
110+
self.key_filenames = key_filenames
111+
self.host_key_policy = host_key_policy
102112
self.logger = logger or logging.getLogger(__name__)
103113

104114
self._ssh_client: Optional[paramiko.SSHClient] = None
@@ -120,7 +130,8 @@ def start(self):
120130
"""Start SSH connection and local forwarding server."""
121131
self._ssh_client = paramiko.SSHClient()
122132
self._ssh_client.load_system_host_keys()
123-
self._ssh_client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
133+
policy_cls = self.HOST_KEY_POLICIES.get(self.host_key_policy, paramiko.AutoAddPolicy)
134+
self._ssh_client.set_missing_host_key_policy(policy_cls())
124135

125136
connect_kwargs: dict[str, Any] = {
126137
"hostname": self.ssh_hostname,
@@ -131,6 +142,8 @@ def start(self):
131142
"compress": False,
132143
"timeout": 10,
133144
}
145+
if self.key_filenames:
146+
connect_kwargs["key_filename"] = self.key_filenames
134147
if self.ssh_password:
135148
connect_kwargs["password"] = self.ssh_password
136149
if self.ssh_proxy:
@@ -189,6 +202,7 @@ def __init__(
189202
dsn_ssh_tunnel_config: Optional[dict] = None,
190203
logger: Optional[logging.Logger] = None,
191204
allow_agent: bool = True,
205+
host_key_policy: str = "auto-add",
192206
):
193207
"""
194208
Initialize SSH tunnel manager.
@@ -199,13 +213,15 @@ def __init__(
199213
dsn_ssh_tunnel_config: Dict of dsn_regex -> tunnel_url mappings
200214
logger: Logger instance for debug output
201215
allow_agent: Whether to allow SSH agent for key authentication (default True)
216+
host_key_policy: SSH host key policy: 'auto-add', 'warn', or 'reject' (default 'auto-add')
202217
"""
203218
self.ssh_tunnel_url = ssh_tunnel_url
204219
self.ssh_tunnel_config = ssh_tunnel_config or {}
205220
self.dsn_ssh_tunnel_config = dsn_ssh_tunnel_config or {}
206221
self.logger = logger or logging.getLogger(__name__)
207222
self.tunnel: Optional[_NativeSSHTunnel] = None
208223
self.allow_agent = allow_agent
224+
self.host_key_policy = host_key_policy
209225

210226
def find_tunnel_url(
211227
self,
@@ -289,10 +305,14 @@ def start_tunnel(
289305
ssh_username = tunnel_info.username
290306
ssh_proxy = None
291307

292-
# Read SSH config manually to get username/port/proxycommand
293-
# but NOT IdentityFile (which would cause passphrase prompts).
294-
# We rely on ssh-agent for key authentication instead.
308+
# Read SSH config for username/port/proxycommand/identityfile.
309+
# IdentityFile entries are read in order (host-specific first, wildcard after)
310+
# and passed as key_filename to paramiko. Paramiko tries each key in order,
311+
# skipping any that fail (e.g. passphrase-protected without agent).
312+
# look_for_keys remains False to prevent blind scanning of ~/.ssh/.
313+
# Auth order: key_filename (specific->wildcard) -> agent -> password
295314
ssh_config_path = os.path.expanduser("~/.ssh/config")
315+
key_filenames = []
296316
if ssh_hostname and os.path.isfile(ssh_config_path):
297317
try:
298318
ssh_config = paramiko.SSHConfig()
@@ -307,16 +327,21 @@ def start_tunnel(
307327
proxycommand = host_config.get("proxycommand")
308328
if proxycommand:
309329
ssh_proxy = paramiko.ProxyCommand(proxycommand)
330+
identity_files = host_config.get("identityfile", [])
331+
key_filenames = [os.path.expanduser(f) for f in identity_files
332+
if os.path.isfile(os.path.expanduser(f))]
333+
if key_filenames:
334+
self.logger.debug("SSH identity files from config: %s", key_filenames)
310335
except Exception as e:
311336
self.logger.warning("Could not read SSH config: %s", e)
312337

313338
if not ssh_username:
314339
ssh_username = getpass.getuser()
315340

316341
self.logger.debug(
317-
"Creating SSH tunnel: %s@%s:%d -> %s:%d (allow_agent=%s)",
342+
"Creating SSH tunnel: %s@%s:%d -> %s:%d (allow_agent=%s, key_files=%d)",
318343
ssh_username, ssh_hostname, ssh_port, host, int(port),
319-
self.allow_agent,
344+
self.allow_agent, len(key_filenames),
320345
)
321346

322347
try:
@@ -329,6 +354,8 @@ def start_tunnel(
329354
ssh_password=tunnel_info.password,
330355
ssh_proxy=ssh_proxy,
331356
allow_agent=self.allow_agent,
357+
key_filenames=key_filenames or None,
358+
host_key_policy=self.host_key_policy,
332359
logger=self.logger,
333360
)
334361
tunnel.start()
@@ -377,11 +404,13 @@ def get_tunnel_manager_from_config(
377404
# Extract allow_agent from ssh tunnels config (default True)
378405
ssh_tunnels_config = config.get("ssh tunnels", {})
379406
allow_agent = str(ssh_tunnels_config.get("allow_agent", "True")).lower() == "true"
407+
host_key_policy = str(ssh_tunnels_config.get("host_key_policy", "auto-add")).lower()
380408

381409
return SSHTunnelManager(
382410
ssh_tunnel_url=ssh_tunnel_url,
383411
ssh_tunnel_config=ssh_tunnels_config,
384412
dsn_ssh_tunnel_config=config.get("dsn ssh tunnels"),
385413
logger=logger,
386414
allow_agent=allow_agent,
415+
host_key_policy=host_key_policy,
387416
)

tests/test_ssh_tunnel.py

Lines changed: 183 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import logging
22
import os
3-
from unittest.mock import patch, MagicMock, ANY, call
3+
from unittest.mock import patch, MagicMock, ANY, call, mock_open
44

5+
import paramiko
56
import pytest
67
from configobj import ConfigObj
78
from click.testing import CliRunner
@@ -544,6 +545,176 @@ def test_proxy_command_passed(self, mock_native_tunnel):
544545
connect_kwargs = mock_native_tunnel["client"].connect.call_args[1]
545546
assert connect_kwargs["sock"] is mock_proxy
546547

548+
def test_key_filenames_passed_to_connect(self, mock_native_tunnel):
549+
"""Test that key_filenames are passed as key_filename to connect()."""
550+
key_files = ["/home/user/.ssh/id_ed25519", "/home/user/.ssh/id_rsa"]
551+
tunnel = _NativeSSHTunnel(
552+
ssh_hostname="bastion",
553+
ssh_port=22,
554+
remote_host="db.internal",
555+
remote_port=5432,
556+
ssh_username="testuser",
557+
key_filenames=key_files,
558+
)
559+
tunnel.start()
560+
561+
connect_kwargs = mock_native_tunnel["client"].connect.call_args[1]
562+
assert connect_kwargs["key_filename"] == key_files
563+
assert connect_kwargs["look_for_keys"] is False # Still disabled
564+
565+
def test_no_key_filenames_omits_key_filename(self, mock_native_tunnel):
566+
"""Test that key_filename is NOT passed when key_filenames is None."""
567+
tunnel = _NativeSSHTunnel(
568+
ssh_hostname="bastion",
569+
ssh_port=22,
570+
remote_host="db.internal",
571+
remote_port=5432,
572+
ssh_username="testuser",
573+
)
574+
tunnel.start()
575+
576+
connect_kwargs = mock_native_tunnel["client"].connect.call_args[1]
577+
assert "key_filename" not in connect_kwargs
578+
579+
def test_host_key_policy_auto_add(self, mock_native_tunnel):
580+
"""Test that auto-add policy sets AutoAddPolicy."""
581+
tunnel = _NativeSSHTunnel(
582+
ssh_hostname="bastion", ssh_port=22,
583+
remote_host="db.internal", remote_port=5432,
584+
host_key_policy="auto-add",
585+
)
586+
tunnel.start()
587+
policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0]
588+
assert isinstance(policy_arg, paramiko.AutoAddPolicy)
589+
590+
def test_host_key_policy_warn(self, mock_native_tunnel):
591+
"""Test that warn policy sets WarningPolicy."""
592+
tunnel = _NativeSSHTunnel(
593+
ssh_hostname="bastion", ssh_port=22,
594+
remote_host="db.internal", remote_port=5432,
595+
host_key_policy="warn",
596+
)
597+
tunnel.start()
598+
policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0]
599+
assert isinstance(policy_arg, paramiko.WarningPolicy)
600+
601+
def test_host_key_policy_reject(self, mock_native_tunnel):
602+
"""Test that reject policy sets RejectPolicy."""
603+
tunnel = _NativeSSHTunnel(
604+
ssh_hostname="bastion", ssh_port=22,
605+
remote_host="db.internal", remote_port=5432,
606+
host_key_policy="reject",
607+
)
608+
tunnel.start()
609+
policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0]
610+
assert isinstance(policy_arg, paramiko.RejectPolicy)
611+
612+
def test_host_key_policy_default_is_auto_add(self, mock_native_tunnel):
613+
"""Test that default policy is auto-add."""
614+
tunnel = _NativeSSHTunnel(
615+
ssh_hostname="bastion", ssh_port=22,
616+
remote_host="db.internal", remote_port=5432,
617+
)
618+
tunnel.start()
619+
policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0]
620+
assert isinstance(policy_arg, paramiko.AutoAddPolicy)
621+
622+
def test_host_key_policy_invalid_falls_back_to_auto_add(self, mock_native_tunnel):
623+
"""Test that invalid policy name falls back to AutoAddPolicy."""
624+
tunnel = _NativeSSHTunnel(
625+
ssh_hostname="bastion", ssh_port=22,
626+
remote_host="db.internal", remote_port=5432,
627+
host_key_policy="nonsense",
628+
)
629+
tunnel.start()
630+
policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0]
631+
assert isinstance(policy_arg, paramiko.AutoAddPolicy)
632+
633+
634+
class TestSSHTunnelIdentityFile:
635+
"""Tests for IdentityFile reading from SSH config."""
636+
637+
def _make_manager_with_ssh_config(self, mock_native_tunnel, host_config, tunnel_url="ssh://bastion.example.com"):
638+
"""Helper: create manager, mock SSH config lookup, run start_tunnel."""
639+
mock_ssh_config = MagicMock()
640+
mock_ssh_config.lookup.return_value = host_config
641+
642+
# Determine which identity files "exist" on disk
643+
existing_files = set(host_config.get("_existing_files", host_config.get("identityfile", [])))
644+
existing_files.add("~/.ssh/config") # SSH config always exists
645+
646+
manager = SSHTunnelManager(
647+
ssh_tunnel_url=tunnel_url,
648+
logger=logging.getLogger("test"),
649+
)
650+
651+
with patch("pgcli.ssh_tunnel.os.path.expanduser", side_effect=lambda p: p), \
652+
patch("pgcli.ssh_tunnel.os.path.isfile", side_effect=lambda p: p in existing_files), \
653+
patch("pgcli.ssh_tunnel.paramiko.SSHConfig") as mock_config_cls, \
654+
patch("builtins.open", mock_open(read_data="")):
655+
mock_config_cls.return_value = mock_ssh_config
656+
host, port = manager.start_tunnel(host="db.internal", port=5432)
657+
658+
return mock_native_tunnel["client"].connect.call_args[1]
659+
660+
def test_start_tunnel_reads_identity_files(self, mock_native_tunnel):
661+
"""Test that start_tunnel reads IdentityFile from SSH config and passes to connect."""
662+
host_config = {
663+
"hostname": "bastion.example.com",
664+
"user": "tunneluser",
665+
"identityfile": ["/home/user/.ssh/id_ed25519_specific", "/home/user/.ssh/id_rsa_wildcard"],
666+
}
667+
668+
connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config)
669+
670+
assert "key_filename" in connect_kwargs
671+
assert connect_kwargs["key_filename"] == [
672+
"/home/user/.ssh/id_ed25519_specific",
673+
"/home/user/.ssh/id_rsa_wildcard",
674+
]
675+
assert connect_kwargs["look_for_keys"] is False
676+
677+
def test_start_tunnel_skips_nonexistent_identity_files(self, mock_native_tunnel):
678+
"""Test that nonexistent IdentityFile entries are filtered out."""
679+
host_config = {
680+
"hostname": "bastion.example.com",
681+
"identityfile": ["/home/user/.ssh/id_ed25519_exists", "/home/user/.ssh/id_rsa_missing"],
682+
"_existing_files": ["/home/user/.ssh/id_ed25519_exists"], # only this one exists
683+
}
684+
685+
connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config)
686+
687+
assert "key_filename" in connect_kwargs
688+
assert connect_kwargs["key_filename"] == ["/home/user/.ssh/id_ed25519_exists"]
689+
690+
def test_start_tunnel_no_identity_files_omits_key_filename(self, mock_native_tunnel):
691+
"""Test that key_filename is omitted when SSH config has no IdentityFile."""
692+
host_config = {
693+
"hostname": "bastion.example.com",
694+
"user": "tunneluser",
695+
}
696+
697+
connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config)
698+
699+
assert "key_filename" not in connect_kwargs
700+
701+
def test_identity_file_order_preserved(self, mock_native_tunnel):
702+
"""Test that IdentityFile order is preserved (host-specific first, wildcard after)."""
703+
host_config = {
704+
"hostname": "bastion.example.com",
705+
"identityfile": [
706+
"/home/user/.ssh/id_ed25519_host", # host-specific (first)
707+
"/home/user/.ssh/id_ed25519_global", # wildcard (second)
708+
],
709+
}
710+
711+
connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config)
712+
713+
assert connect_kwargs["key_filename"] == [
714+
"/home/user/.ssh/id_ed25519_host",
715+
"/home/user/.ssh/id_ed25519_global",
716+
]
717+
547718

548719
class TestGetTunnelManagerFromConfig:
549720
"""Tests for get_tunnel_manager_from_config function."""
@@ -596,3 +767,14 @@ def test_allow_agent_from_config(self):
596767
config = {"ssh tunnels": {"allow_agent": "False"}}
597768
manager = get_tunnel_manager_from_config(config)
598769
assert manager.allow_agent is False
770+
771+
def test_host_key_policy_from_config(self):
772+
"""Test host_key_policy is read from config."""
773+
config = {"ssh tunnels": {"host_key_policy": "reject"}}
774+
manager = get_tunnel_manager_from_config(config)
775+
assert manager.host_key_policy == "reject"
776+
777+
def test_host_key_policy_default(self):
778+
"""Test host_key_policy defaults to auto-add."""
779+
manager = get_tunnel_manager_from_config({})
780+
assert manager.host_key_policy == "auto-add"

0 commit comments

Comments
 (0)