@@ -50,15 +50,13 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
5050 mock_pgexecute .assert_called_once ()
5151
5252 call_args , call_kwargs = mock_pgexecute .call_args
53- assert call_args == (
54- db_params ["database" ],
55- db_params ["user" ],
56- db_params ["passwd" ],
57- "127.0.0.1" ,
58- pgcli .ssh_tunnel .local_bind_ports [0 ],
59- "" ,
60- notify_callback ,
61- )
53+ # Original host is preserved for .pgpass lookup, hostaddr has tunnel endpoint
54+ assert call_args [0 ] == db_params ["database" ] # database
55+ assert call_args [1 ] == db_params ["user" ] # user
56+ assert call_args [2 ] == db_params ["passwd" ] # passwd
57+ assert call_args [3 ] == db_params ["host" ] # original host preserved
58+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ] # tunnel port
59+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1" # tunnel endpoint via hostaddr
6260 mock_ssh_tunnel_forwarder .reset_mock ()
6361 mock_pgexecute .reset_mock ()
6462
@@ -86,15 +84,10 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
8684 mock_pgexecute .assert_called_once ()
8785
8886 call_args , call_kwargs = mock_pgexecute .call_args
89- assert call_args == (
90- db_params ["database" ],
91- db_params ["user" ],
92- db_params ["passwd" ],
93- "127.0.0.1" ,
94- pgcli .ssh_tunnel .local_bind_ports [0 ],
95- "" ,
96- notify_callback ,
97- )
87+ # Original host is preserved, hostaddr has tunnel endpoint
88+ assert call_args [3 ] == db_params ["host" ] # original host preserved
89+ assert call_args [4 ] == pgcli .ssh_tunnel .local_bind_ports [0 ] # tunnel port
90+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
9891 mock_ssh_tunnel_forwarder .reset_mock ()
9992 mock_pgexecute .reset_mock ()
10093
@@ -104,13 +97,78 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM
10497 pgcli = PGCli (ssh_tunnel_url = tunnel_url )
10598 pgcli .connect (dsn = dsn )
10699
107- expected_dsn = f"user={ db_params ['user' ]} password={ db_params ['passwd' ]} host=127.0.0.1 port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} "
108-
109100 mock_ssh_tunnel_forwarder .assert_called_once_with (** expected_tunnel_params )
110101 mock_pgexecute .assert_called_once ()
111102
112103 call_args , call_kwargs = mock_pgexecute .call_args
113- assert expected_dsn in call_args
104+ # DSN should contain original host AND hostaddr for tunnel
105+ dsn_arg = call_args [5 ] # dsn is 6th positional arg
106+ assert f"host={ db_params ['host' ]} " in dsn_arg
107+ assert "hostaddr=127.0.0.1" in dsn_arg
108+ assert f"port={ pgcli .ssh_tunnel .local_bind_ports [0 ]} " in dsn_arg
109+
110+
111+ def test_ssh_tunnel_preserves_original_host_for_pgpass (
112+ mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock
113+ ) -> None :
114+ """Verify that the original hostname is preserved for .pgpass lookup."""
115+ tunnel_url = "bastion.example.com"
116+ original_host = "production.db.example.com"
117+
118+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
119+ pgcli .connect (database = "mydb" , host = original_host , user = "dbuser" , passwd = "dbpass" )
120+
121+ call_args , call_kwargs = mock_pgexecute .call_args
122+ # host parameter should be the original, not 127.0.0.1
123+ assert call_args [3 ] == original_host
124+ # hostaddr should be the tunnel endpoint
125+ assert call_kwargs .get ("hostaddr" ) == "127.0.0.1"
126+
127+
128+ def test_ssh_tunnel_with_dsn_preserves_host (
129+ mock_ssh_tunnel_forwarder : MagicMock , mock_pgexecute : MagicMock
130+ ) -> None :
131+ """DSN connections should include hostaddr for tunnel while preserving host."""
132+ tunnel_url = "bastion.example.com"
133+ dsn = "host=production.db.example.com port=5432 dbname=mydb user=dbuser"
134+
135+ pgcli = PGCli (ssh_tunnel_url = tunnel_url )
136+ pgcli .connect (dsn = dsn )
137+
138+ call_args , call_kwargs = mock_pgexecute .call_args
139+ dsn_arg = call_args [5 ]
140+ assert "host=production.db.example.com" in dsn_arg
141+ assert "hostaddr=127.0.0.1" in dsn_arg
142+
143+
144+ def test_no_ssh_tunnel_does_not_set_hostaddr (mock_pgexecute : MagicMock ) -> None :
145+ """Without SSH tunnel, hostaddr should not be set."""
146+ pgcli = PGCli ()
147+ pgcli .connect (database = "mydb" , host = "localhost" , user = "user" , passwd = "pass" )
148+
149+ call_args , call_kwargs = mock_pgexecute .call_args
150+ assert "hostaddr" not in call_kwargs
151+
152+
153+ def test_connect_uri_without_ssh_tunnel (mock_pgexecute : MagicMock ) -> None :
154+ """connect_uri should work normally without SSH tunnel."""
155+ pgcli = PGCli ()
156+ pgcli .connect_uri ("postgresql://user:pass@localhost/mydb" )
157+
158+ mock_pgexecute .assert_called_once ()
159+ call_args , call_kwargs = mock_pgexecute .call_args
160+ assert "hostaddr" not in call_kwargs
161+
162+
163+ def test_connect_uri_passes_dsn (mock_pgexecute : MagicMock ) -> None :
164+ """connect_uri should pass the URI as dsn parameter."""
165+ uri = "postgresql://user:pass@localhost/mydb"
166+ pgcli = PGCli ()
167+ pgcli .connect_uri (uri )
168+
169+ call_args , call_kwargs = mock_pgexecute .call_args
170+ # dsn is passed as the 6th positional arg to PGExecute.__init__
171+ assert call_args [5 ] == uri
114172
115173
116174def test_cli_with_tunnel () -> None :
0 commit comments