Skip to content

Commit a050131

Browse files
committed
Enable .pgpass support for SSH tunnel connections
Preserve original hostname for .pgpass lookup using PostgreSQL's host/hostaddr parameters: host keeps the original DB hostname (for .pgpass and SSL), hostaddr gets 127.0.0.1 (the tunnel endpoint). Changes: - main.py: Use hostaddr instead of replacing host with 127.0.0.1 - main.py: Pass dsn in connect_uri() for proper .pgpass handling - pgexecute.py: Simplify DSN filtering to keep dsn, password, hostaddr - tests: Add 4 new tests, update existing to verify host preservation Made with ❤️ and 🤖 Claude
1 parent a0c2ee4 commit a050131

File tree

4 files changed

+96
-25
lines changed

4 files changed

+96
-25
lines changed

changelog.rst

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Upcoming (TBD)
44
Features:
55
---------
66
* Add support for `\\T` prompt escape sequence to display transaction status (similar to psql's `%x`).
7+
* Enable ``.pgpass`` support for SSH tunnel connections.
8+
* Preserve original hostname for ``.pgpass`` lookup using PostgreSQL's ``hostaddr`` parameter
9+
* SSH tunnel endpoint (``127.0.0.1``) is passed via ``hostaddr``, keeping ``host`` for ``.pgpass``
10+
* Works with both DSN and host/port connection styles
711

812
4.4.0 (2025-12-24)
913
==================

pgcli/main.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,8 @@ def connect_uri(self, uri):
616616
kwargs = conninfo_to_dict(uri)
617617
remap = {"dbname": "database", "password": "passwd"}
618618
kwargs = {remap.get(k, k): v for k, v in kwargs.items()}
619-
self.connect(**kwargs)
619+
# Pass the original URI as dsn parameter for .pgpass support with SSH tunnels
620+
self.connect(dsn=uri, **kwargs)
620621

621622
def connect(self, database="", host="", user="", port="", passwd="", dsn="", **kwargs):
622623
# Connect to the database.
@@ -709,11 +710,15 @@ def should_ask_for_password(exc):
709710
self.logger.handlers = logger_handlers
710711

711712
atexit.register(self.ssh_tunnel.stop)
712-
host = "127.0.0.1"
713+
# Preserve original host for .pgpass lookup and SSL certificate verification.
714+
# Use hostaddr to specify the actual connection endpoint (SSH tunnel).
715+
hostaddr = "127.0.0.1"
713716
port = self.ssh_tunnel.local_bind_ports[0]
714717

715718
if dsn:
716-
dsn = make_conninfo(dsn, host=host, port=port)
719+
dsn = make_conninfo(dsn, host=host, hostaddr=hostaddr, port=port)
720+
else:
721+
kwargs["hostaddr"] = hostaddr
717722

718723
# Attempt to connect to the database.
719724
# Note that passwd may be empty on the first attempt. If connection

pgcli/pgexecute.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,11 @@ def connect(
212212
new_params.update(kwargs)
213213

214214
if new_params["dsn"]:
215-
new_params = {"dsn": new_params["dsn"], "password": new_params["password"]}
215+
# When using DSN, only keep dsn, password, and hostaddr (for SSH tunnels)
216+
new_params = {
217+
k: v for k, v in new_params.items()
218+
if k in ("dsn", "password", "hostaddr")
219+
}
216220

217221
if new_params["password"]:
218222
new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password"))

tests/test_ssh_tunnel.py

Lines changed: 79 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050
mock_pgexecute.assert_called_once()
5151

5252
call_args, call_kwargs = mock_pgexecute.call_args
53-
assert call_args == (
54-
db_params["database"],
55-
db_params["user"],
56-
db_params["passwd"],
57-
"127.0.0.1",
58-
pgcli.ssh_tunnel.local_bind_ports[0],
59-
"",
60-
notify_callback,
61-
)
53+
# Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+
assert call_args[0] == db_params["database"] # database
55+
assert call_args[1] == db_params["user"] # user
56+
assert call_args[2] == db_params["passwd"] # passwd
57+
assert call_args[3] == db_params["host"] # original host preserved
58+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0] # tunnel port
59+
assert call_kwargs.get("hostaddr") == "127.0.0.1" # tunnel endpoint via hostaddr
6260
mock_ssh_tunnel_forwarder.reset_mock()
6361
mock_pgexecute.reset_mock()
6462

@@ -86,15 +84,10 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684
mock_pgexecute.assert_called_once()
8785

8886
call_args, call_kwargs = mock_pgexecute.call_args
89-
assert call_args == (
90-
db_params["database"],
91-
db_params["user"],
92-
db_params["passwd"],
93-
"127.0.0.1",
94-
pgcli.ssh_tunnel.local_bind_ports[0],
95-
"",
96-
notify_callback,
97-
)
87+
# Original host is preserved, hostaddr has tunnel endpoint
88+
assert call_args[3] == db_params["host"] # original host preserved
89+
assert call_args[4] == pgcli.ssh_tunnel.local_bind_ports[0] # tunnel port
90+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
9891
mock_ssh_tunnel_forwarder.reset_mock()
9992
mock_pgexecute.reset_mock()
10093

@@ -104,13 +97,78 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10497
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
10598
pgcli.connect(dsn=dsn)
10699

107-
expected_dsn = f"user={db_params['user']} password={db_params['passwd']} host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}"
108-
109100
mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params)
110101
mock_pgexecute.assert_called_once()
111102

112103
call_args, call_kwargs = mock_pgexecute.call_args
113-
assert expected_dsn in call_args
104+
# DSN should contain original host AND hostaddr for tunnel
105+
dsn_arg = call_args[5] # dsn is 6th positional arg
106+
assert f"host={db_params['host']}" in dsn_arg
107+
assert "hostaddr=127.0.0.1" in dsn_arg
108+
assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg
109+
110+
111+
def test_ssh_tunnel_preserves_original_host_for_pgpass(
112+
mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
113+
) -> None:
114+
"""Verify that the original hostname is preserved for .pgpass lookup."""
115+
tunnel_url = "bastion.example.com"
116+
original_host = "production.db.example.com"
117+
118+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
119+
pgcli.connect(database="mydb", host=original_host, user="dbuser", passwd="dbpass")
120+
121+
call_args, call_kwargs = mock_pgexecute.call_args
122+
# host parameter should be the original, not 127.0.0.1
123+
assert call_args[3] == original_host
124+
# hostaddr should be the tunnel endpoint
125+
assert call_kwargs.get("hostaddr") == "127.0.0.1"
126+
127+
128+
def test_ssh_tunnel_with_dsn_preserves_host(
129+
mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock
130+
) -> None:
131+
"""DSN connections should include hostaddr for tunnel while preserving host."""
132+
tunnel_url = "bastion.example.com"
133+
dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
134+
135+
pgcli = PGCli(ssh_tunnel_url=tunnel_url)
136+
pgcli.connect(dsn=dsn)
137+
138+
call_args, call_kwargs = mock_pgexecute.call_args
139+
dsn_arg = call_args[5]
140+
assert "host=production.db.example.com" in dsn_arg
141+
assert "hostaddr=127.0.0.1" in dsn_arg
142+
143+
144+
def test_no_ssh_tunnel_does_not_set_hostaddr(mock_pgexecute: MagicMock) -> None:
145+
"""Without SSH tunnel, hostaddr should not be set."""
146+
pgcli = PGCli()
147+
pgcli.connect(database="mydb", host="localhost", user="user", passwd="pass")
148+
149+
call_args, call_kwargs = mock_pgexecute.call_args
150+
assert "hostaddr" not in call_kwargs
151+
152+
153+
def test_connect_uri_without_ssh_tunnel(mock_pgexecute: MagicMock) -> None:
154+
"""connect_uri should work normally without SSH tunnel."""
155+
pgcli = PGCli()
156+
pgcli.connect_uri("postgresql://user:pass@localhost/mydb")
157+
158+
mock_pgexecute.assert_called_once()
159+
call_args, call_kwargs = mock_pgexecute.call_args
160+
assert "hostaddr" not in call_kwargs
161+
162+
163+
def test_connect_uri_passes_dsn(mock_pgexecute: MagicMock) -> None:
164+
"""connect_uri should pass the URI as dsn parameter."""
165+
uri = "postgresql://user:pass@localhost/mydb"
166+
pgcli = PGCli()
167+
pgcli.connect_uri(uri)
168+
169+
call_args, call_kwargs = mock_pgexecute.call_args
170+
# dsn is passed as the 6th positional arg to PGExecute.__init__
171+
assert call_args[5] == uri
114172

115173

116174
def test_cli_with_tunnel() -> None:

0 commit comments

Comments
 (0)