|
1 | 1 | import logging |
2 | 2 | import os |
3 | | -from unittest.mock import patch, MagicMock, ANY, call |
| 3 | +from unittest.mock import patch, MagicMock, ANY, call, mock_open |
4 | 4 |
|
| 5 | +import paramiko |
5 | 6 | import pytest |
6 | 7 | from configobj import ConfigObj |
7 | 8 | from click.testing import CliRunner |
@@ -544,6 +545,176 @@ def test_proxy_command_passed(self, mock_native_tunnel): |
544 | 545 | connect_kwargs = mock_native_tunnel["client"].connect.call_args[1] |
545 | 546 | assert connect_kwargs["sock"] is mock_proxy |
546 | 547 |
|
| 548 | + def test_key_filenames_passed_to_connect(self, mock_native_tunnel): |
| 549 | + """Test that key_filenames are passed as key_filename to connect().""" |
| 550 | + key_files = ["/home/user/.ssh/id_ed25519", "/home/user/.ssh/id_rsa"] |
| 551 | + tunnel = _NativeSSHTunnel( |
| 552 | + ssh_hostname="bastion", |
| 553 | + ssh_port=22, |
| 554 | + remote_host="db.internal", |
| 555 | + remote_port=5432, |
| 556 | + ssh_username="testuser", |
| 557 | + key_filenames=key_files, |
| 558 | + ) |
| 559 | + tunnel.start() |
| 560 | + |
| 561 | + connect_kwargs = mock_native_tunnel["client"].connect.call_args[1] |
| 562 | + assert connect_kwargs["key_filename"] == key_files |
| 563 | + assert connect_kwargs["look_for_keys"] is False # Still disabled |
| 564 | + |
| 565 | + def test_no_key_filenames_omits_key_filename(self, mock_native_tunnel): |
| 566 | + """Test that key_filename is NOT passed when key_filenames is None.""" |
| 567 | + tunnel = _NativeSSHTunnel( |
| 568 | + ssh_hostname="bastion", |
| 569 | + ssh_port=22, |
| 570 | + remote_host="db.internal", |
| 571 | + remote_port=5432, |
| 572 | + ssh_username="testuser", |
| 573 | + ) |
| 574 | + tunnel.start() |
| 575 | + |
| 576 | + connect_kwargs = mock_native_tunnel["client"].connect.call_args[1] |
| 577 | + assert "key_filename" not in connect_kwargs |
| 578 | + |
| 579 | + def test_host_key_policy_auto_add(self, mock_native_tunnel): |
| 580 | + """Test that auto-add policy sets AutoAddPolicy.""" |
| 581 | + tunnel = _NativeSSHTunnel( |
| 582 | + ssh_hostname="bastion", ssh_port=22, |
| 583 | + remote_host="db.internal", remote_port=5432, |
| 584 | + host_key_policy="auto-add", |
| 585 | + ) |
| 586 | + tunnel.start() |
| 587 | + policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0] |
| 588 | + assert isinstance(policy_arg, paramiko.AutoAddPolicy) |
| 589 | + |
| 590 | + def test_host_key_policy_warn(self, mock_native_tunnel): |
| 591 | + """Test that warn policy sets WarningPolicy.""" |
| 592 | + tunnel = _NativeSSHTunnel( |
| 593 | + ssh_hostname="bastion", ssh_port=22, |
| 594 | + remote_host="db.internal", remote_port=5432, |
| 595 | + host_key_policy="warn", |
| 596 | + ) |
| 597 | + tunnel.start() |
| 598 | + policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0] |
| 599 | + assert isinstance(policy_arg, paramiko.WarningPolicy) |
| 600 | + |
| 601 | + def test_host_key_policy_reject(self, mock_native_tunnel): |
| 602 | + """Test that reject policy sets RejectPolicy.""" |
| 603 | + tunnel = _NativeSSHTunnel( |
| 604 | + ssh_hostname="bastion", ssh_port=22, |
| 605 | + remote_host="db.internal", remote_port=5432, |
| 606 | + host_key_policy="reject", |
| 607 | + ) |
| 608 | + tunnel.start() |
| 609 | + policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0] |
| 610 | + assert isinstance(policy_arg, paramiko.RejectPolicy) |
| 611 | + |
| 612 | + def test_host_key_policy_default_is_auto_add(self, mock_native_tunnel): |
| 613 | + """Test that default policy is auto-add.""" |
| 614 | + tunnel = _NativeSSHTunnel( |
| 615 | + ssh_hostname="bastion", ssh_port=22, |
| 616 | + remote_host="db.internal", remote_port=5432, |
| 617 | + ) |
| 618 | + tunnel.start() |
| 619 | + policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0] |
| 620 | + assert isinstance(policy_arg, paramiko.AutoAddPolicy) |
| 621 | + |
| 622 | + def test_host_key_policy_invalid_falls_back_to_auto_add(self, mock_native_tunnel): |
| 623 | + """Test that invalid policy name falls back to AutoAddPolicy.""" |
| 624 | + tunnel = _NativeSSHTunnel( |
| 625 | + ssh_hostname="bastion", ssh_port=22, |
| 626 | + remote_host="db.internal", remote_port=5432, |
| 627 | + host_key_policy="nonsense", |
| 628 | + ) |
| 629 | + tunnel.start() |
| 630 | + policy_arg = mock_native_tunnel["client"].set_missing_host_key_policy.call_args[0][0] |
| 631 | + assert isinstance(policy_arg, paramiko.AutoAddPolicy) |
| 632 | + |
| 633 | + |
| 634 | +class TestSSHTunnelIdentityFile: |
| 635 | + """Tests for IdentityFile reading from SSH config.""" |
| 636 | + |
| 637 | + def _make_manager_with_ssh_config(self, mock_native_tunnel, host_config, tunnel_url="ssh://bastion.example.com"): |
| 638 | + """Helper: create manager, mock SSH config lookup, run start_tunnel.""" |
| 639 | + mock_ssh_config = MagicMock() |
| 640 | + mock_ssh_config.lookup.return_value = host_config |
| 641 | + |
| 642 | + # Determine which identity files "exist" on disk |
| 643 | + existing_files = set(host_config.get("_existing_files", host_config.get("identityfile", []))) |
| 644 | + existing_files.add("~/.ssh/config") # SSH config always exists |
| 645 | + |
| 646 | + manager = SSHTunnelManager( |
| 647 | + ssh_tunnel_url=tunnel_url, |
| 648 | + logger=logging.getLogger("test"), |
| 649 | + ) |
| 650 | + |
| 651 | + with patch("pgcli.ssh_tunnel.os.path.expanduser", side_effect=lambda p: p), \ |
| 652 | + patch("pgcli.ssh_tunnel.os.path.isfile", side_effect=lambda p: p in existing_files), \ |
| 653 | + patch("pgcli.ssh_tunnel.paramiko.SSHConfig") as mock_config_cls, \ |
| 654 | + patch("builtins.open", mock_open(read_data="")): |
| 655 | + mock_config_cls.return_value = mock_ssh_config |
| 656 | + host, port = manager.start_tunnel(host="db.internal", port=5432) |
| 657 | + |
| 658 | + return mock_native_tunnel["client"].connect.call_args[1] |
| 659 | + |
| 660 | + def test_start_tunnel_reads_identity_files(self, mock_native_tunnel): |
| 661 | + """Test that start_tunnel reads IdentityFile from SSH config and passes to connect.""" |
| 662 | + host_config = { |
| 663 | + "hostname": "bastion.example.com", |
| 664 | + "user": "tunneluser", |
| 665 | + "identityfile": ["/home/user/.ssh/id_ed25519_specific", "/home/user/.ssh/id_rsa_wildcard"], |
| 666 | + } |
| 667 | + |
| 668 | + connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config) |
| 669 | + |
| 670 | + assert "key_filename" in connect_kwargs |
| 671 | + assert connect_kwargs["key_filename"] == [ |
| 672 | + "/home/user/.ssh/id_ed25519_specific", |
| 673 | + "/home/user/.ssh/id_rsa_wildcard", |
| 674 | + ] |
| 675 | + assert connect_kwargs["look_for_keys"] is False |
| 676 | + |
| 677 | + def test_start_tunnel_skips_nonexistent_identity_files(self, mock_native_tunnel): |
| 678 | + """Test that nonexistent IdentityFile entries are filtered out.""" |
| 679 | + host_config = { |
| 680 | + "hostname": "bastion.example.com", |
| 681 | + "identityfile": ["/home/user/.ssh/id_ed25519_exists", "/home/user/.ssh/id_rsa_missing"], |
| 682 | + "_existing_files": ["/home/user/.ssh/id_ed25519_exists"], # only this one exists |
| 683 | + } |
| 684 | + |
| 685 | + connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config) |
| 686 | + |
| 687 | + assert "key_filename" in connect_kwargs |
| 688 | + assert connect_kwargs["key_filename"] == ["/home/user/.ssh/id_ed25519_exists"] |
| 689 | + |
| 690 | + def test_start_tunnel_no_identity_files_omits_key_filename(self, mock_native_tunnel): |
| 691 | + """Test that key_filename is omitted when SSH config has no IdentityFile.""" |
| 692 | + host_config = { |
| 693 | + "hostname": "bastion.example.com", |
| 694 | + "user": "tunneluser", |
| 695 | + } |
| 696 | + |
| 697 | + connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config) |
| 698 | + |
| 699 | + assert "key_filename" not in connect_kwargs |
| 700 | + |
| 701 | + def test_identity_file_order_preserved(self, mock_native_tunnel): |
| 702 | + """Test that IdentityFile order is preserved (host-specific first, wildcard after).""" |
| 703 | + host_config = { |
| 704 | + "hostname": "bastion.example.com", |
| 705 | + "identityfile": [ |
| 706 | + "/home/user/.ssh/id_ed25519_host", # host-specific (first) |
| 707 | + "/home/user/.ssh/id_ed25519_global", # wildcard (second) |
| 708 | + ], |
| 709 | + } |
| 710 | + |
| 711 | + connect_kwargs = self._make_manager_with_ssh_config(mock_native_tunnel, host_config) |
| 712 | + |
| 713 | + assert connect_kwargs["key_filename"] == [ |
| 714 | + "/home/user/.ssh/id_ed25519_host", |
| 715 | + "/home/user/.ssh/id_ed25519_global", |
| 716 | + ] |
| 717 | + |
547 | 718 |
|
548 | 719 | class TestGetTunnelManagerFromConfig: |
549 | 720 | """Tests for get_tunnel_manager_from_config function.""" |
@@ -596,3 +767,14 @@ def test_allow_agent_from_config(self): |
596 | 767 | config = {"ssh tunnels": {"allow_agent": "False"}} |
597 | 768 | manager = get_tunnel_manager_from_config(config) |
598 | 769 | assert manager.allow_agent is False |
| 770 | + |
| 771 | + def test_host_key_policy_from_config(self): |
| 772 | + """Test host_key_policy is read from config.""" |
| 773 | + config = {"ssh tunnels": {"host_key_policy": "reject"}} |
| 774 | + manager = get_tunnel_manager_from_config(config) |
| 775 | + assert manager.host_key_policy == "reject" |
| 776 | + |
| 777 | + def test_host_key_policy_default(self): |
| 778 | + """Test host_key_policy defaults to auto-add.""" |
| 779 | + manager = get_tunnel_manager_from_config({}) |
| 780 | + assert manager.host_key_policy == "auto-add" |
0 commit comments