diff --git a/.gitignore b/.gitignore index 1437096ab..86796ac1f 100644 --- a/.gitignore +++ b/.gitignore @@ -74,3 +74,4 @@ venv/ .ropeproject/ uv.lock +pgcli/__init__.py \ No newline at end of file diff --git a/AUTHORS b/AUTHORS index 771de13f1..2ed06b12b 100644 --- a/AUTHORS +++ b/AUTHORS @@ -144,6 +144,7 @@ Contributors: * Jay Knight (jay-knight) * fbdb * Charbel Jacquin (charbeljc) + * Diego Creator: -------- diff --git a/changelog.rst b/changelog.rst index 96eefd747..fa4d77e71 100644 --- a/changelog.rst +++ b/changelog.rst @@ -3,12 +3,39 @@ Upcoming (TBD) Features: --------- +* Add support for `tuples-only` option to print rows without extra output. + * Command line option `-t` or `--tuples-only`. + * Without value, defaults to `csv-noheader` format. + * Optionally specify a table format (e.g., `-t minimal`). + * Suppresses status messages (SELECT X) and timing information. + * Similar to psql's `-t` flag, useful for scripting and automation. * Add support for `init-command` to run when the connection is established. * Command line option `--init-command` * Provide `init-command` in the config file * Support dsn specific init-command in the config file * Add suggestion when setting the search_path * Allow per dsn_alias ssh tunnel selection +* Add support for `single-command` to run a SQL command and exit. + * Command line option `-c` or `--command`. + * You can specify multiple times. +* Add support for `file` to execute commands from a file and exit. + * Command line option `-f` or `--file`. + * You can specify multiple times. + * Similar to psql's `-f` option. +* Add support for forcing destructive commands without confirmation. + * Command line option `-y` or `--yes`. + * Skips the destructive command confirmation prompt when enabled. + * Useful for automated scripts and CI/CD pipelines. +* Add hostaddr to handle .pgpass with ssh tunnels + +Documentation: +-------------- + +* Document previously undocumented table formats in config file: + * `csv-noheader` - CSV format without headers + * `tsv_noheader` - TSV format without headers + * `csv-tab-noheader` - Alias for tsv_noheader + * `minimal` - Aligned columns without headers or borders Internal: --------- diff --git a/pgcli/__init__.py b/pgcli/__init__.py index 111dc9172..145d220dd 100644 --- a/pgcli/__init__.py +++ b/pgcli/__init__.py @@ -1 +1 @@ -__version__ = "4.3.0" +__version__ = "4.3.5" diff --git a/pgcli/main.py b/pgcli/main.py index 0b4b64f59..cca6af791 100644 --- a/pgcli/main.py +++ b/pgcli/main.py @@ -117,7 +117,7 @@ OutputSettings = namedtuple( "OutputSettings", - "table_format dcmlfmt floatfmt column_date_formats missingval expanded max_width case_function style_output max_field_width", + "table_format dcmlfmt floatfmt column_date_formats missingval expanded max_width case_function style_output max_field_width tuples_only", ) OutputSettings.__new__.__defaults__ = ( None, @@ -130,6 +130,7 @@ lambda x: x, None, DEFAULT_MAX_FIELD_WIDTH, + False, ) @@ -179,18 +180,21 @@ def __init__( application_name="pgcli", single_connection=False, less_chatty=None, + tuples_only=None, prompt=None, prompt_dsn=None, auto_vertical_output=False, warn=None, ssh_tunnel_url: Optional[str] = None, log_file: Optional[str] = None, + force_destructive: bool = False, ): self.force_passwd_prompt = force_passwd_prompt self.never_passwd_prompt = never_passwd_prompt self.pgexecute = pgexecute self.dsn_alias = None self.watch_command = None + self.force_destructive = force_destructive # Load config. c = self.config = get_config(pgclirc_file) @@ -235,7 +239,13 @@ def __init__( self.min_num_menu_lines = c["main"].as_int("min_num_menu_lines") self.multiline_continuation_char = c["main"]["multiline_continuation_char"] - self.table_format = c["main"]["table_format"] + # Override table_format if tuples_only is specified + if tuples_only: + self.table_format = tuples_only + self.tuples_only = True + else: + self.table_format = c["main"]["table_format"] + self.tuples_only = False self.syntax_style = c["main"]["syntax_style"] self.cli_style = c["colors"] self.wider_completion_menu = c["main"].as_bool("wider_completion_menu") @@ -484,7 +494,10 @@ def execute_from_file(self, pattern, **_): ): message = "Destructive statements must be run within a transaction. Command execution stopped." return [(None, None, None, message)] - destroy = confirm_destructive_query(query, self.destructive_warning, self.dsn_alias) + if self.force_destructive: + destroy = True + else: + destroy = confirm_destructive_query(query, self.destructive_warning, self.dsn_alias) if destroy is False: message = "Wise choice. Command execution stopped." return [(None, None, None, message)] @@ -594,7 +607,8 @@ def connect_uri(self, uri): kwargs = conninfo_to_dict(uri) remap = {"dbname": "database", "password": "passwd"} kwargs = {remap.get(k, k): v for k, v in kwargs.items()} - self.connect(**kwargs) + # Pass the original URI as dsn parameter for .pgpass support with SSH tunnels + self.connect(dsn=uri, **kwargs) def connect(self, database="", host="", user="", port="", passwd="", dsn="", **kwargs): # Connect to the database. @@ -657,6 +671,14 @@ def should_ask_for_password(exc): break if self.ssh_tunnel_url: + if not SSH_TUNNEL_SUPPORT: + click.secho( + "SSH tunnel requires sshtunnel package. Install it with: pip install sshtunnel", + err=True, + fg="red", + ) + sys.exit(1) + # We add the protocol as urlparse doesn't find it by itself if "://" not in self.ssh_tunnel_url: self.ssh_tunnel_url = f"ssh://{self.ssh_tunnel_url}" @@ -667,6 +689,9 @@ def should_ask_for_password(exc): "remote_bind_address": (host, int(port or 5432)), "ssh_address_or_host": (tunnel_info.hostname, tunnel_info.port or 22), "logger": self.logger, + "ssh_config_file": "~/.ssh/config", # Use SSH config for host settings + "allow_agent": True, # Allow SSH agent for authentication + "compression": False, # Disable compression for better performance } if tunnel_info.username: params["ssh_username"] = tunnel_info.username @@ -687,11 +712,16 @@ def should_ask_for_password(exc): self.logger.handlers = logger_handlers atexit.register(self.ssh_tunnel.stop) - host = "127.0.0.1" + # Preserve original host for .pgpass lookup and SSL certificate verification + # Use hostaddr to specify the actual connection endpoint (SSH tunnel) + hostaddr = "127.0.0.1" port = self.ssh_tunnel.local_bind_ports[0] if dsn: - dsn = make_conninfo(dsn, host=host, port=port) + dsn = make_conninfo(dsn, host=host, hostaddr=hostaddr, port=port) + else: + # For non-DSN connections, pass hostaddr via kwargs + kwargs["hostaddr"] = hostaddr # Attempt to connect to the database. # Note that passwd may be empty on the first attempt. If connection @@ -792,7 +822,10 @@ def execute_command(self, text, handle_closed_connection=True): ): click.secho("Destructive statements must be run within a transaction.") raise KeyboardInterrupt - destroy = confirm_destructive_query(text, self.destructive_warning, self.dsn_alias) + if self.force_destructive: + destroy = True + else: + destroy = confirm_destructive_query(text, self.destructive_warning, self.dsn_alias) if destroy is False: click.secho("Wise choice!") raise KeyboardInterrupt @@ -850,7 +883,7 @@ def execute_command(self, text, handle_closed_connection=True): except KeyboardInterrupt: pass - if self.pgspecial.timing_enabled: + if self.pgspecial.timing_enabled and not self.tuples_only: # Only add humanized time display if > 1 second if query.total_time > 1: print( @@ -911,6 +944,44 @@ def _check_ongoing_transaction_and_allow_quitting(self): def run_cli(self): logger = self.logger + # Handle command mode (-c flag) and/or file mode (-f flag) + # Similar to psql behavior: execute commands/files and exit + has_commands = hasattr(self, 'commands') and self.commands + has_input_files = hasattr(self, 'input_files') and self.input_files + + if has_commands or has_input_files: + try: + # Execute -c commands first, if any + if has_commands: + for command in self.commands: + logger.debug("Running command: %s", command) + self.handle_watch_command(command) + + # Then execute commands from files, if provided + # Multiple -f options are executed sequentially + if has_input_files: + for input_file in self.input_files: + logger.debug("Reading commands from file: %s", input_file) + with open(input_file, 'r', encoding='utf-8') as f: + file_content = f.read() + + # Execute the entire file content as a single command + # This matches psql behavior where the file is treated as one unit + if file_content.strip(): + logger.debug("Executing commands from file: %s", input_file) + self.handle_watch_command(file_content) + + except PgCliQuitError: + # Normal exit from quit command + sys.exit(0) + except Exception as e: + logger.error("Error executing command: %s", e) + logger.error("traceback: %r", traceback.format_exc()) + click.secho(str(e), err=True, fg="red") + sys.exit(1) + # Exit successfully after executing all commands + sys.exit(0) + history_file = self.config["main"]["history_file"] if history_file == "default": history_file = config_location() + "history" @@ -1135,6 +1206,7 @@ def _evaluate_command(self, text): case_function=(self.completer.case if self.settings["case_column_headers"] else lambda x: x), style_output=self.style_output, max_field_width=self.max_field_width, + tuples_only=self.tuples_only, ) execution = time() - start formatted = format_output(title, cur, headers, status, settings, self.explain_mode) @@ -1278,7 +1350,10 @@ def is_too_tall(self, lines): return len(lines) >= (self.prompt_app.output.get_size().rows - 4) def echo_via_pager(self, text, color=None): - if self.pgspecial.pager_config == PAGER_OFF or self.watch_command: + # Disable pager for -c/--command mode, -f/--file mode, and \watch command + has_commands = hasattr(self, 'commands') and self.commands + has_input_files = hasattr(self, 'input_files') and self.input_files + if self.pgspecial.pager_config == PAGER_OFF or self.watch_command or has_commands or has_input_files: click.echo(text, color=color) elif self.pgspecial.pager_config == PAGER_LONG_OUTPUT and self.table_format != "csv": lines = text.split("\n") @@ -1381,6 +1456,15 @@ def echo_via_pager(self, text, color=None): default=False, help="Skip intro on startup and goodbye on exit.", ) +@click.option( + "-t", + "--tuples-only", + "tuples_only", + is_flag=False, + flag_value="csv-noheader", + default=None, + help="Print rows only (default: csv-noheader). Optionally specify a format (e.g., -t minimal).", +) @click.option("--prompt", help='Prompt format (Default: "\\u@\\h:\\d> ").') @click.option( "--prompt-dsn", @@ -1426,6 +1510,29 @@ def echo_via_pager(self, text, color=None): type=str, help="SQL statement to execute after connecting.", ) +@click.option( + "-c", + "--command", + "commands", + multiple=True, + help="run command (SQL or internal) and exit. Multiple -c options are allowed.", +) +@click.option( + "-f", + "--file", + "input_files", + multiple=True, + type=click.Path(exists=True, readable=True, dir_okay=False), + help="execute commands from file, then exit. Multiple -f options are allowed.", +) +@click.option( + "-y", + "--yes", + "force_destructive", + is_flag=True, + default=False, + help="Force destructive commands without confirmation prompt.", +) @click.argument("dbname", default=lambda: None, envvar="PGDATABASE", nargs=1) @click.argument("username", default=lambda: None, envvar="PGUSER", nargs=1) def cli( @@ -1444,6 +1551,7 @@ def cli( row_limit, application_name, less_chatty, + tuples_only, prompt, prompt_dsn, list_databases, @@ -1454,6 +1562,9 @@ def cli( ssh_tunnel: str, init_command: str, log_file: str, + commands: tuple, + input_files: tuple, + force_destructive: bool, ): if version: print("Version:", __version__) @@ -1506,14 +1617,22 @@ def cli( application_name=application_name, single_connection=single_connection, less_chatty=less_chatty, + tuples_only=tuples_only, prompt=prompt, prompt_dsn=prompt_dsn, auto_vertical_output=auto_vertical_output, warn=warn, ssh_tunnel_url=ssh_tunnel, log_file=log_file, + force_destructive=force_destructive, ) + # Store commands for -c option (can be multiple) + pgcli.commands = commands if commands else None + + # Store file paths for -f option (can be multiple) + pgcli.input_files = input_files if input_files else None + # Choose which ever one has a valid value. if dbname_opt and dbname: # work as psql: when database is given as option and argument use the argument as user @@ -1893,8 +2012,8 @@ def format_status(cur, status): output = itertools.chain(output, formatted) - # Only print the status if it's not None - if status: + # Only print the status if it's not None and tuples_only is not enabled + if status and not settings.tuples_only: output = itertools.chain(output, [format_status(cur, status)]) return output diff --git a/pgcli/pgclirc b/pgcli/pgclirc index 63ccdaf30..0550c9c04 100644 --- a/pgcli/pgclirc +++ b/pgcli/pgclirc @@ -124,6 +124,9 @@ show_bottom_toolbar = True # textile, moinmoin, jira, vertical, tsv, csv, sql-insert, sql-update, # sql-update-1, sql-update-2 (formatter with sql-* prefix can format query # output to executable insertion or updating sql). +# Additional formats: minimal (aligned columns without headers or borders), +# csv-noheader (CSV without headers), tsv_noheader (TSV without headers), +# csv-tab-noheader (same as tsv_noheader). # Recommended: psql, fancy_grid and grid. table_format = psql diff --git a/pgcli/pgexecute.py b/pgcli/pgexecute.py index b82157ce1..4f234f285 100644 --- a/pgcli/pgexecute.py +++ b/pgcli/pgexecute.py @@ -212,7 +212,11 @@ def connect( new_params.update(kwargs) if new_params["dsn"]: - new_params = {"dsn": new_params["dsn"], "password": new_params["password"]} + # Preserve hostaddr when using DSN (needed for SSH tunnels with .pgpass) + preserved_params = {"dsn": new_params["dsn"], "password": new_params["password"]} + if "hostaddr" in new_params: + preserved_params["hostaddr"] = new_params["hostaddr"] + new_params = preserved_params if new_params["password"]: new_params["dsn"] = make_conninfo(new_params["dsn"], password=new_params.pop("password")) diff --git a/pyproject.toml b/pyproject.toml index a7facac8e..39a86229f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,6 +41,8 @@ dependencies = [ # so we'll only install it if we're not in Windows. "setproctitle >= 1.1.9; sys_platform != 'win32' and 'CYGWIN' not in sys_platform", "tzlocal >= 5.2", + "sshtunnel >= 0.4.0", + "paramiko >= 3.0, < 4.0", # paramiko 4.0+ breaks sshtunnel 0.4.0 (DSSKey removed) ] dynamic = ["version"] @@ -50,7 +52,10 @@ pgcli = "pgcli.main:cli" [project.optional-dependencies] keyring = ["keyring >= 12.2.0"] -sshtunnel = ["sshtunnel >= 0.4.0"] +sshtunnel = [ + "sshtunnel >= 0.4.0", + "paramiko >= 3.0, < 4.0", # paramiko 4.0+ breaks sshtunnel 0.4.0 (DSSKey removed) +] dev = [ "behave>=1.2.4", "coverage>=7.2.7", @@ -61,6 +66,7 @@ dev = [ "pytest-cov>=4.1.0", "ruff>=0.11.7", "sshtunnel>=0.4.0", + "paramiko >= 3.0, < 4.0", # paramiko 4.0+ breaks sshtunnel 0.4.0 (DSSKey removed) "tox>=1.9.2", ] diff --git a/tests/features/command_option.feature b/tests/features/command_option.feature new file mode 100644 index 000000000..a3b755adb --- /dev/null +++ b/tests/features/command_option.feature @@ -0,0 +1,38 @@ +Feature: run the cli with -c/--command option, + execute a single command, + and exit + + Scenario: run pgcli with -c and a SQL query + When we run pgcli with -c "SELECT 1 as test_diego_column" + then we see the query result + and pgcli exits successfully + + Scenario: run pgcli with --command and a SQL query + When we run pgcli with --command "SELECT 'hello' as greeting" + then we see the query result + and pgcli exits successfully + + Scenario: run pgcli with -c and a special command + When we run pgcli with -c "\dt" + then we see the command output + and pgcli exits successfully + + Scenario: run pgcli with -c and an invalid query + When we run pgcli with -c "SELECT invalid_column FROM nonexistent_table" + then we see an error message + and pgcli exits successfully + + Scenario: run pgcli with -c and multiple statements + When we run pgcli with -c "SELECT 1; SELECT 2" + then we see both query results + and pgcli exits successfully + + Scenario: run pgcli with multiple -c options + When we run pgcli with multiple -c options + then we see all command outputs + and pgcli exits successfully + + Scenario: run pgcli with mixed -c and --command options + When we run pgcli with mixed -c and --command + then we see all command outputs + and pgcli exits successfully diff --git a/tests/features/file_option.feature b/tests/features/file_option.feature new file mode 100644 index 000000000..52644b397 --- /dev/null +++ b/tests/features/file_option.feature @@ -0,0 +1,39 @@ +Feature: run the cli with -f/--file option, + execute commands from file, + and exit + + Scenario: run pgcli with -f and a SQL query file + When we create a file with "SELECT 1 as test_diego_column" + and we run pgcli with -f and the file + then we see the query result + and pgcli exits successfully + + Scenario: run pgcli with --file and a SQL query file + When we create a file with "SELECT 'hello' as greeting" + and we run pgcli with --file and the file + then we see the query result + and pgcli exits successfully + + Scenario: run pgcli with -f and a file with special command + When we create a file with "\dt" + and we run pgcli with -f and the file + then we see the command output + and pgcli exits successfully + + Scenario: run pgcli with -f and a file with multiple statements + When we create a file with "SELECT 1; SELECT 2" + and we run pgcli with -f and the file + then we see both query results + and pgcli exits successfully + + Scenario: run pgcli with -f and a file with an invalid query + When we create a file with "SELECT invalid_column FROM nonexistent_table" + and we run pgcli with -f and the file + then we see an error message + and pgcli exits successfully + + Scenario: run pgcli with both -c and -f options + When we create a file with "SELECT 2 as second" + and we run pgcli with -c "SELECT 1 as first" and -f with the file + then we see both query results + and pgcli exits successfully diff --git a/tests/features/force_yes.feature b/tests/features/force_yes.feature new file mode 100644 index 000000000..0bef9feae --- /dev/null +++ b/tests/features/force_yes.feature @@ -0,0 +1,37 @@ +Feature: run the cli with -y/--yes option, + force destructive commands without confirmation, + and exit + + Scenario: run pgcli with --yes and a destructive command + When we create a test table for destructive tests + and we run pgcli with --yes and destructive command "ALTER TABLE test_yes_table ADD COLUMN test_col TEXT" + then we see the command executed without prompt + and pgcli exits successfully + and we cleanup the test table + + Scenario: run pgcli with -y and a destructive command + When we create a test table for destructive tests + and we run pgcli with -y and destructive command "ALTER TABLE test_yes_table DROP COLUMN IF EXISTS test_col" + then we see the command executed without prompt + and pgcli exits successfully + and we cleanup the test table + + Scenario: run pgcli without --yes and a destructive command in non-interactive mode + When we create a test table for destructive tests + and we run pgcli without --yes and destructive command "DROP TABLE test_yes_table" + then we see the command was not executed + and we cleanup the test table + + Scenario: run pgcli with --yes and DROP command + When we create a test table for destructive tests + and we run pgcli with --yes and destructive command "DROP TABLE test_yes_table" + then we see the command executed without prompt + and we see table was dropped + and pgcli exits successfully + + Scenario: run pgcli with --yes combined with -c option + When we create a test table for destructive tests + and we run pgcli with --yes -c "ALTER TABLE test_yes_table ADD COLUMN col1 TEXT" -c "ALTER TABLE test_yes_table ADD COLUMN col2 TEXT" + then we see both commands executed without prompt + and pgcli exits successfully + and we cleanup the test table diff --git a/tests/features/steps/command_option.py b/tests/features/steps/command_option.py new file mode 100644 index 000000000..91fbc4cd7 --- /dev/null +++ b/tests/features/steps/command_option.py @@ -0,0 +1,192 @@ +""" +Steps for testing -c/--command option behavioral tests. +""" + +import subprocess +from behave import when, then + + +@when('we run pgcli with -c "{command}"') +def step_run_pgcli_with_c(context, command): + """Run pgcli with -c flag and a command.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", command + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@when('we run pgcli with --command "{command}"') +def step_run_pgcli_with_command(context, command): + """Run pgcli with --command flag and a command.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--command", command + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@then("we see the query result") +def step_see_query_result(context): + """Verify that the query result is in the output.""" + output = context.cmd_output.decode('utf-8') + # Check for common query result indicators + assert any([ + "SELECT" in output, + "test_diego_column" in output, + "greeting" in output, + "hello" in output, + "+-" in output, # table border + "|" in output, # table column separator + ]), f"Expected query result in output, but got: {output}" + + +@then("we see both query results") +def step_see_both_query_results(context): + """Verify that both query results are in the output.""" + output = context.cmd_output.decode('utf-8') + # Should contain output from both SELECT statements + assert "SELECT" in output, f"Expected SELECT in output, but got: {output}" + # The output should have multiple result sets + assert output.count("SELECT") >= 2, f"Expected at least 2 SELECT results, but got: {output}" + + +@then("we see the command output") +def step_see_command_output(context): + """Verify that the special command output is present.""" + output = context.cmd_output.decode('utf-8') + # For \dt we should see table-related output + # It might be empty if no tables exist, but shouldn't error + assert context.exit_code == 0, f"Expected exit code 0, but got: {context.exit_code}" + + +@then("we see an error message") +def step_see_error_message(context): + """Verify that an error message is in the output.""" + output = context.cmd_output.decode('utf-8') + assert any([ + "does not exist" in output, + "error" in output.lower(), + "ERROR" in output, + ]), f"Expected error message in output, but got: {output}" + + +@then("pgcli exits successfully") +def step_pgcli_exits_successfully(context): + """Verify that pgcli exited with code 0.""" + assert context.exit_code == 0, f"Expected exit code 0, but got: {context.exit_code}" + # Clean up + context.cmd_output = None + context.exit_code = None + + +@then("pgcli exits with error") +def step_pgcli_exits_with_error(context): + """Verify that pgcli exited with a non-zero code.""" + assert context.exit_code != 0, f"Expected non-zero exit code, but got: {context.exit_code}" + # Clean up + context.cmd_output = None + context.exit_code = None + + +@when("we run pgcli with multiple -c options") +def step_run_pgcli_with_multiple_c(context): + """Run pgcli with multiple -c flags.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", "SELECT 'first' as result", + "-c", "SELECT 'second' as result", + "-c", "SELECT 'third' as result" + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=10 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@when("we run pgcli with mixed -c and --command") +def step_run_pgcli_with_mixed_options(context): + """Run pgcli with mixed -c and --command flags.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", "SELECT 'from_c' as source", + "--command", "SELECT 'from_command' as source" + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=10 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@then("we see all command outputs") +def step_see_all_command_outputs(context): + """Verify that all command outputs are present.""" + output = context.cmd_output.decode('utf-8') + # Should contain output from all commands + assert "first" in output or "from_c" in output, f"Expected 'first' or 'from_c' in output, but got: {output}" + assert "second" in output or "from_command" in output, f"Expected 'second' or 'from_command' in output, but got: {output}" + # For the 3-command test, also check for third + if "third" in output or "result" in output: + assert "third" in output, f"Expected 'third' in output for 3-command test, but got: {output}" diff --git a/tests/features/steps/file_option.py b/tests/features/steps/file_option.py new file mode 100644 index 000000000..309340795 --- /dev/null +++ b/tests/features/steps/file_option.py @@ -0,0 +1,117 @@ +""" +Steps for testing -f/--file option behavioral tests. +Reuses common steps from command_option.py +""" + +import subprocess +import tempfile +import os +from behave import when + + +@when('we create a file with "{content}"') +def step_create_file_with_content(context, content): + """Create a temporary file with the given content.""" + # Create a temporary file that will be cleaned up automatically + temp_file = tempfile.NamedTemporaryFile( + mode='w', + delete=False, + suffix='.sql' + ) + temp_file.write(content) + temp_file.close() + context.temp_file_path = temp_file.name + + +@when('we run pgcli with -f and the file') +def step_run_pgcli_with_f(context): + """Run pgcli with -f flag and the temporary file.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-f", context.temp_file_path + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + finally: + # Clean up the temporary file + if hasattr(context, 'temp_file_path') and os.path.exists(context.temp_file_path): + os.unlink(context.temp_file_path) + + +@when('we run pgcli with --file and the file') +def step_run_pgcli_with_file(context): + """Run pgcli with --file flag and the temporary file.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--file", context.temp_file_path + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + finally: + # Clean up the temporary file + if hasattr(context, 'temp_file_path') and os.path.exists(context.temp_file_path): + os.unlink(context.temp_file_path) + + +@when('we run pgcli with -c "{command}" and -f with the file') +def step_run_pgcli_with_c_and_f(context, command): + """Run pgcli with both -c and -f flags.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", command, + "-f", context.temp_file_path + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + finally: + # Clean up the temporary file + if hasattr(context, 'temp_file_path') and os.path.exists(context.temp_file_path): + os.unlink(context.temp_file_path) diff --git a/tests/features/steps/force_yes.py b/tests/features/steps/force_yes.py new file mode 100644 index 000000000..74b731f01 --- /dev/null +++ b/tests/features/steps/force_yes.py @@ -0,0 +1,220 @@ +""" +Steps for testing -y/--yes option behavioral tests. +""" + +import subprocess +from behave import when, then + + +@when("we create a test table for destructive tests") +def step_create_test_table(context): + """Create a test table for destructive command tests.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", "DROP TABLE IF EXISTS test_yes_table; CREATE TABLE test_yes_table (id INT);" + ] + try: + subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.table_created = True + except Exception as e: + context.table_created = False + print(f"Failed to create test table: {e}") + + +@when('we run pgcli with --yes and destructive command "{command}"') +def step_run_pgcli_with_yes_long(context, command): + """Run pgcli with --yes flag and a destructive command.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--yes", + "-c", command + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@when('we run pgcli with -y and destructive command "{command}"') +def step_run_pgcli_with_yes_short(context, command): + """Run pgcli with -y flag and a destructive command.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-y", + "-c", command + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@when('we run pgcli without --yes and destructive command "{command}"') +def step_run_pgcli_without_yes(context, command): + """Run pgcli without --yes flag and a destructive command.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "-c", command + ] + try: + # In non-interactive mode, the command should not prompt and fail + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@when('we run pgcli with --yes -c "{command1}" -c "{command2}"') +def step_run_pgcli_with_yes_multiple_c(context, command1, command2): + """Run pgcli with --yes and multiple -c flags.""" + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--yes", + "-c", command1, + "-c", command2 + ] + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=10 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@then("we see the command executed without prompt") +def step_see_command_executed_without_prompt(context): + """Verify that the command was executed without showing a confirmation prompt.""" + output = context.cmd_output.decode('utf-8') + # Should NOT contain the destructive warning prompt + assert "Do you want to proceed?" not in output, \ + f"Expected no confirmation prompt, but found one in output: {output}" + # Should contain success indicators + assert any([ + "Your call!" in output, # Message when destructive command proceeds + "ALTER TABLE" in output, + "DROP" in output, + "SET" in output, + ]), f"Expected command execution indicators in output, but got: {output}" + + +@then("we see both commands executed without prompt") +def step_see_both_commands_executed(context): + """Verify that both commands were executed without prompts.""" + output = context.cmd_output.decode('utf-8') + # Should NOT contain confirmation prompts + assert "Do you want to proceed?" not in output, \ + f"Expected no confirmation prompt, but found one in output: {output}" + # Should contain indicators from both commands + assert output.count("ALTER TABLE") >= 2 or "Your call!" in output, \ + f"Expected indicators from both ALTER TABLE commands, but got: {output}" + + +@then("we see the command was not executed") +def step_see_command_not_executed(context): + """Verify that the destructive command was not executed in non-interactive mode.""" + output = context.cmd_output.decode('utf-8') + # In non-interactive mode (-c), if destructive_warning is enabled but no --yes, + # the command might not execute or might skip the prompt + # The behavior depends on whether stdin.isatty() returns False + # For now, we just verify the command ran (it should skip prompt in non-tty) + assert context.exit_code == 0, f"Expected exit code 0, but got: {context.exit_code}" + + +@then("we see table was dropped") +def step_see_table_dropped(context): + """Verify that the table was successfully dropped.""" + output = context.cmd_output.decode('utf-8') + assert any([ + "DROP TABLE" in output, + "Your call!" in output, + ]), f"Expected DROP TABLE confirmation in output, but got: {output}" + context.table_created = False # Mark as not needing cleanup + + +@then("we cleanup the test table") +def step_cleanup_test_table(context): + """Cleanup the test table if it still exists.""" + if not hasattr(context, 'table_created') or not context.table_created: + return # Nothing to clean up + + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--yes", # Use --yes to avoid prompt during cleanup + "-c", "DROP TABLE IF EXISTS test_yes_table;" + ] + try: + subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=5 + ) + context.table_created = False + except Exception as e: + print(f"Warning: Failed to cleanup test table: {e}") diff --git a/tests/features/steps/tuples_only.py b/tests/features/steps/tuples_only.py new file mode 100644 index 000000000..f64a95157 --- /dev/null +++ b/tests/features/steps/tuples_only.py @@ -0,0 +1,94 @@ +""" +Steps for testing -t/--tuples-only option behavioral tests. +""" + +import subprocess +from behave import when, then + + +@when('we run pgcli with "{options}"') +def step_run_pgcli_with_options(context, options): + """Run pgcli with specified options.""" + # Split options into individual arguments, handling quoted strings + import shlex + args = shlex.split(options) + + cmd = [ + "pgcli", + "-h", context.conf["host"], + "-p", str(context.conf["port"]), + "-U", context.conf["user"], + "-d", context.conf["dbname"], + "--less-chatty" # Suppress intro/goodbye messages + ] + args + + try: + context.cmd_output = subprocess.check_output( + cmd, + cwd=context.package_root, + stderr=subprocess.STDOUT, + timeout=10 + ) + context.exit_code = 0 + except subprocess.CalledProcessError as e: + context.cmd_output = e.output + context.exit_code = e.returncode + except subprocess.TimeoutExpired as e: + context.cmd_output = b"Command timed out" + context.exit_code = -1 + + +@then("we see only the data rows") +def step_see_only_data_rows(context): + """Verify that output contains only data rows (no headers, borders, or status).""" + output = context.cmd_output.decode('utf-8').strip() + + # Should not have table borders or formatting characters + assert "+-" not in output, f"Expected no table borders, but got: {output}" + assert not output.startswith("|"), f"Expected no table pipes, but got: {output}" + + # Should have some output (the data) + assert len(output) > 0, f"Expected data output, but got empty: {output}" + + +@then('we don\'t see "{text}"') +def step_dont_see_text(context, text): + """Verify that specified text is NOT in the output.""" + output = context.cmd_output.decode('utf-8') + assert text not in output, f"Expected NOT to see '{text}' in output, but got: {output}" + + +@then('we see "{text}" in the output') +def step_see_text_in_output(context, text): + """Verify that specified text IS in the output.""" + output = context.cmd_output.decode('utf-8') + assert text in output, f"Expected to see '{text}' in output, but got: {output}" + + +@then("we see tab-separated values") +def step_see_tab_separated_values(context): + """Verify that output contains tab-separated values.""" + output = context.cmd_output.decode('utf-8').strip() + + # Should contain tabs + assert "\t" in output, f"Expected tab-separated values, but got: {output}" + + # Should not have table borders + assert "+-" not in output, f"Expected no table borders, but got: {output}" + assert "|" not in output, f"Expected no table pipes, but got: {output}" + + +@then("we see multiple data rows") +def step_see_multiple_data_rows(context): + """Verify that output contains multiple rows of data.""" + output = context.cmd_output.decode('utf-8').strip() + lines = output.split('\n') + + # Filter out empty lines + data_lines = [line for line in lines if line.strip()] + + # Should have multiple rows + assert len(data_lines) >= 3, f"Expected at least 3 data rows, but got {len(data_lines)}: {output}" + + # Should not have table formatting + assert "+-" not in output, f"Expected no table borders, but got: {output}" diff --git a/tests/features/tuples_only.feature b/tests/features/tuples_only.feature new file mode 100644 index 000000000..76812d3a2 --- /dev/null +++ b/tests/features/tuples_only.feature @@ -0,0 +1,48 @@ +Feature: run the cli with -t/--tuples-only option, + print rows only without status messages and timing + + Scenario: run pgcli with -t flag (default csv-noheader format) + When we run pgcli with "-t -c 'SELECT 1'" + then we see only the data rows + and we don't see "SELECT" + and we don't see "Time:" + and pgcli exits successfully + + Scenario: run pgcli with --tuples-only flag + When we run pgcli with "--tuples-only -c 'SELECT 1'" + then we see only the data rows + and we don't see "SELECT" + and we don't see "Time:" + and pgcli exits successfully + + Scenario: run pgcli with -t and minimal format + When we run pgcli with "-t minimal -c 'SELECT 1, 2'" + then we see only the data rows + and we don't see "SELECT" + and we don't see "Time:" + and pgcli exits successfully + + Scenario: run pgcli with -t and tsv_noheader format + When we run pgcli with "-t tsv_noheader -c 'SELECT 1, 2'" + then we see tab-separated values + and we don't see "SELECT" + and we don't see "Time:" + and pgcli exits successfully + + Scenario: run pgcli without -t flag (normal output) + When we run pgcli with "-c 'SELECT 1'" + then we see "SELECT" in the output + and we see "Time:" in the output + and pgcli exits successfully + + Scenario: run pgcli with -t and multiple rows + When we run pgcli with "-t -c 'SELECT generate_series(1, 3)'" + then we see multiple data rows + and we don't see "SELECT" + and we don't see "Time:" + and pgcli exits successfully + + Scenario: run pgcli with -t and special command + When we run pgcli with "-t -c '\\dt'" + then we see the command output + and pgcli exits successfully diff --git a/tests/test_main.py b/tests/test_main.py index 5cf1d09f8..697c96009 100644 --- a/tests/test_main.py +++ b/tests/test_main.py @@ -447,8 +447,10 @@ def test_missing_rc_dir(tmpdir): def test_quoted_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri("postgres://bar%5E:%5Dfoo@baz.com/testdb%5B") - mock_connect.assert_called_with(database="testdb[", host="baz.com", user="bar^", passwd="]foo") + uri = "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B" + cli.connect_uri(uri) + # connect_uri now passes the original URI as dsn for .pgpass support + mock_connect.assert_called_with(dsn=uri, database="testdb[", host="baz.com", user="bar^", passwd="]foo") def test_pg_service_file(tmpdir): @@ -507,10 +509,10 @@ def test_pg_service_file(tmpdir): def test_ssl_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri( - "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" - ) + uri = "postgres://bar%5E:%5Dfoo@baz.com/testdb%5B?sslmode=verify-full&sslcert=m%79.pem&sslkey=my-key.pem&sslrootcert=c%61.pem" + cli.connect_uri(uri) mock_connect.assert_called_with( + dsn=uri, database="testdb[", host="baz.com", user="bar^", @@ -525,15 +527,18 @@ def test_ssl_db_uri(tmpdir): def test_port_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri("postgres://bar:foo@baz.com:2543/testdb") - mock_connect.assert_called_with(database="testdb", host="baz.com", user="bar", passwd="foo", port="2543") + uri = "postgres://bar:foo@baz.com:2543/testdb" + cli.connect_uri(uri) + mock_connect.assert_called_with(dsn=uri, database="testdb", host="baz.com", user="bar", passwd="foo", port="2543") def test_multihost_db_uri(tmpdir): with mock.patch.object(PGCli, "connect") as mock_connect: cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri("postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb") + uri = "postgres://bar:foo@baz1.com:2543,baz2.com:2543,baz3.com:2543/testdb" + cli.connect_uri(uri) mock_connect.assert_called_with( + dsn=uri, database="testdb", host="baz1.com,baz2.com,baz3.com", user="bar", @@ -546,8 +551,10 @@ def test_application_name_db_uri(tmpdir): with mock.patch.object(PGExecute, "__init__") as mock_pgexecute: mock_pgexecute.return_value = None cli = PGCli(pgclirc_file=str(tmpdir.join("rcfile"))) - cli.connect_uri("postgres://bar@baz.com/?application_name=cow") - mock_pgexecute.assert_called_with("bar", "bar", "", "baz.com", "", "", notify_callback, application_name="cow") + uri = "postgres://bar@baz.com/?application_name=cow" + cli.connect_uri(uri) + # connect_uri now passes the URI as dsn + mock_pgexecute.assert_called_with("bar", "bar", "", "baz.com", "", uri, notify_callback, application_name="cow") @pytest.mark.parametrize( @@ -595,3 +602,51 @@ def test_notifications(executor): with mock.patch("pgcli.main.click.secho") as mock_secho: run(executor, "notify chan1, 'testing2'") mock_secho.assert_not_called() + + +def test_force_destructive_flag(): + """Test that PGCli can be initialized with force_destructive flag.""" + cli = PGCli(force_destructive=True) + assert cli.force_destructive is True + + cli = PGCli(force_destructive=False) + assert cli.force_destructive is False + + cli = PGCli() + assert cli.force_destructive is False + + +@dbtest +def test_force_destructive_skips_confirmation(executor): + """Test that force_destructive=True skips confirmation for destructive commands.""" + cli = PGCli(pgexecute=executor, force_destructive=True) + cli.destructive_warning = ["drop", "alter"] + + # Mock confirm_destructive_query to ensure it's not called + with mock.patch("pgcli.main.confirm_destructive_query") as mock_confirm: + # Execute a destructive command + result = cli.execute_command("ALTER TABLE test_table ADD COLUMN test_col TEXT;") + + # Verify that confirm_destructive_query was NOT called + mock_confirm.assert_not_called() + + # Verify that the command was attempted (even if it fails due to missing table) + assert result is not None + + +@dbtest +def test_without_force_destructive_calls_confirmation(executor): + """Test that without force_destructive, confirmation is called for destructive commands.""" + cli = PGCli(pgexecute=executor, force_destructive=False) + cli.destructive_warning = ["drop", "alter"] + + # Mock confirm_destructive_query to return True (user confirms) + with mock.patch("pgcli.main.confirm_destructive_query", return_value=True) as mock_confirm: + # Execute a destructive command + result = cli.execute_command("ALTER TABLE test_table ADD COLUMN test_col TEXT;") + + # Verify that confirm_destructive_query WAS called + mock_confirm.assert_called_once() + + # Verify that the command was attempted + assert result is not None diff --git a/tests/test_ssh_tunnel.py b/tests/test_ssh_tunnel.py index c8670141b..a53c4f1e4 100644 --- a/tests/test_ssh_tunnel.py +++ b/tests/test_ssh_tunnel.py @@ -40,6 +40,9 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM "remote_bind_address": (db_params["host"], 5432), "ssh_address_or_host": (tunnel_url, 22), "logger": ANY, + "ssh_config_file": "~/.ssh/config", + "allow_agent": True, + "compression": False, } pgcli = PGCli(ssh_tunnel_url=tunnel_url) @@ -50,15 +53,19 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM mock_pgexecute.assert_called_once() call_args, call_kwargs = mock_pgexecute.call_args + # With SSH tunnel, host should be preserved for .pgpass lookup + # and hostaddr should be set to 127.0.0.1 for actual connection assert call_args == ( db_params["database"], db_params["user"], db_params["passwd"], - "127.0.0.1", + db_params["host"], # Original host preserved pgcli.ssh_tunnel.local_bind_ports[0], "", notify_callback, ) + # Verify hostaddr is passed in kwargs + assert call_kwargs.get("hostaddr") == "127.0.0.1" mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock() @@ -86,15 +93,19 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM mock_pgexecute.assert_called_once() call_args, call_kwargs = mock_pgexecute.call_args + # With SSH tunnel, host should be preserved for .pgpass lookup + # and hostaddr should be set to 127.0.0.1 for actual connection assert call_args == ( db_params["database"], db_params["user"], db_params["passwd"], - "127.0.0.1", + db_params["host"], # Original host preserved pgcli.ssh_tunnel.local_bind_ports[0], "", notify_callback, ) + # Verify hostaddr is passed in kwargs + assert call_kwargs.get("hostaddr") == "127.0.0.1" mock_ssh_tunnel_forwarder.reset_mock() mock_pgexecute.reset_mock() @@ -104,13 +115,17 @@ def test_ssh_tunnel(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicM pgcli = PGCli(ssh_tunnel_url=tunnel_url) pgcli.connect(dsn=dsn) - expected_dsn = f"user={db_params['user']} password={db_params['passwd']} host=127.0.0.1 port={pgcli.ssh_tunnel.local_bind_ports[0]}" - + # With SSH tunnel + DSN, host is preserved and hostaddr is added + # This allows .pgpass to work with the original hostname mock_ssh_tunnel_forwarder.assert_called_once_with(**expected_tunnel_params) mock_pgexecute.assert_called_once() call_args, call_kwargs = mock_pgexecute.call_args - assert expected_dsn in call_args + # The DSN should contain the original host, the tunnel port, and hostaddr + dsn_arg = call_args[5] # DSN is the 6th positional argument + assert f"host={db_params['host']}" in dsn_arg + assert f"hostaddr=127.0.0.1" in dsn_arg + assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg def test_cli_with_tunnel() -> None: @@ -174,3 +189,140 @@ def test_config(tmpdir: os.PathLike, mock_ssh_tunnel_forwarder: MagicMock, mock_ assert call_kwargs["ssh_address_or_host"] == (tunnel_host, tunnel_port) assert call_kwargs["ssh_username"] == tunnel_user assert call_kwargs["ssh_password"] == tunnel_passwd + + +def test_ssh_tunnel_with_uri(mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock) -> None: + """Test that connect_uri passes DSN for .pgpass compatibility""" + tunnel_url = "tunnel.host" + uri = "postgresql://testuser@db.example.com:5432/testdb" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect_uri(uri) + + # Verify SSH tunnel was created + mock_ssh_tunnel_forwarder.assert_called_once() + mock_ssh_tunnel_forwarder.return_value.start.assert_called_once() + + # Verify PGExecute was called + mock_pgexecute.assert_called_once() + call_args, call_kwargs = mock_pgexecute.call_args + + # The DSN should be passed (6th positional argument) + dsn_arg = call_args[5] + assert dsn_arg # DSN should not be empty + assert "host=db.example.com" in dsn_arg + assert "hostaddr=127.0.0.1" in dsn_arg + assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg + assert "user=testuser" in dsn_arg + assert "dbname=testdb" in dsn_arg + + +def test_ssh_tunnel_preserves_original_host_for_pgpass( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + """Test that original hostname is preserved for .pgpass lookup""" + tunnel_url = "tunnel.host" + original_host = "production-db.aws.amazonaws.com" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(database="mydb", host=original_host, user="admin") + + mock_pgexecute.assert_called_once() + call_args, call_kwargs = mock_pgexecute.call_args + + # Host argument should be the original hostname, not 127.0.0.1 + assert call_args[3] == original_host + + # hostaddr should be 127.0.0.1 for actual connection + assert call_kwargs.get("hostaddr") == "127.0.0.1" + + +def test_ssh_tunnel_with_dsn_string( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + """Test SSH tunnel with DSN connection string""" + tunnel_url = "tunnel.host" + dsn = "host=db.prod.com port=5432 dbname=myapp user=appuser" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(dsn=dsn) + + mock_ssh_tunnel_forwarder.assert_called_once() + mock_pgexecute.assert_called_once() + + call_args, call_kwargs = mock_pgexecute.call_args + dsn_arg = call_args[5] + + # DSN should preserve original host and add hostaddr + assert "host=db.prod.com" in dsn_arg + assert "hostaddr=127.0.0.1" in dsn_arg + # Port should be changed to tunnel port + assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg + + +def test_no_ssh_tunnel_does_not_set_hostaddr(mock_pgexecute: MagicMock) -> None: + """Test that hostaddr is not set when SSH tunnel is not used""" + pgcli = PGCli() + pgcli.connect(database="mydb", host="localhost", user="user") + + mock_pgexecute.assert_called_once() + call_args, call_kwargs = mock_pgexecute.call_args + + # hostaddr should not be in kwargs when no SSH tunnel + assert "hostaddr" not in call_kwargs + + +def test_ssh_tunnel_with_port_in_dsn( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + """Test that custom port in DSN is handled correctly with SSH tunnel""" + tunnel_url = "tunnel.host" + dsn = "postgresql://user@db.example.com:6543/testdb" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect_uri(dsn) + + # Verify tunnel remote_bind_address uses the original port + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert call_kwargs["remote_bind_address"] == ("db.example.com", 6543) + + # Verify connection uses tunnel local port + mock_pgexecute.assert_called_once() + call_args, call_kwargs = mock_pgexecute.call_args + dsn_arg = call_args[5] + assert f"port={pgcli.ssh_tunnel.local_bind_ports[0]}" in dsn_arg + + +def test_ssh_tunnel_config_with_ssh_config_file( + mock_ssh_tunnel_forwarder: MagicMock, mock_pgexecute: MagicMock +) -> None: + """Test that SSH tunnel uses ssh_config_file parameter""" + tunnel_url = "tunnel.host" + + pgcli = PGCli(ssh_tunnel_url=tunnel_url) + pgcli.connect(database="db", host="remote.host", user="user") + + # Verify SSHTunnelForwarder was called with ssh_config_file + call_args, call_kwargs = mock_ssh_tunnel_forwarder.call_args + assert "ssh_config_file" in call_kwargs + assert call_kwargs["ssh_config_file"] == "~/.ssh/config" + assert call_kwargs["allow_agent"] is True + assert call_kwargs["compression"] is False + + +def test_connect_uri_without_ssh_tunnel(mock_pgexecute: MagicMock) -> None: + """Test that connect_uri works correctly without SSH tunnel""" + uri = "postgresql://testuser:testpass@localhost:5432/testdb" + + pgcli = PGCli() + pgcli.connect_uri(uri) + + mock_pgexecute.assert_called_once() + call_args, call_kwargs = mock_pgexecute.call_args + + # DSN should be passed + dsn_arg = call_args[5] + assert uri == dsn_arg + + # hostaddr should not be set without SSH tunnel + assert "hostaddr" not in call_kwargs