diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 2c31e1d8..df37a3cf 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -24,6 +24,8 @@ Release 0.11.0 (unreleased) * Add more tests and documentation for patching (#888) * Restrict ``src`` to string only in schema (#888) * Don't consider ignored files for determining local changes (#350) +* Avoid waiting for user input in ``git`` & ``svn`` commands (#570) +* Extend git ssh command to run in BatchMode (#570) Release 0.10.0 (released 2025-03-12) ==================================== diff --git a/dfetch/project/svn.py b/dfetch/project/svn.py index b56ef289..09f8824e 100644 --- a/dfetch/project/svn.py +++ b/dfetch/project/svn.py @@ -52,6 +52,7 @@ def externals() -> list[External]: logger, [ "svn", + "--non-interactive", "propget", "svn:externals", "-R", @@ -130,13 +131,13 @@ def _split_url(url: str, repo_root: str) -> tuple[str, str, str, str]: def check(self) -> bool: """Check if is SVN.""" try: - run_on_cmdline(logger, f"svn info {self.remote} --non-interactive") + run_on_cmdline(logger, ["svn", "info", self.remote, "--non-interactive"]) return True except SubprocessCommandError as exc: - if exc.stdout.startswith("svn: E170013"): + if exc.stderr.startswith("svn: E170013"): raise RuntimeError( f">>>{exc.cmd}<<< failed!\n" - + f"'{self.remote}' is not a valid URL or unreachable:\n{exc.stderr or exc.stdout}" + + f"'{self.remote}' is not a valid URL or unreachable:\n{exc.stdout or exc.stderr}" ) from exc return False except RuntimeError: @@ -147,7 +148,7 @@ def check_path(path: str = ".") -> bool: """Check if is SVN.""" try: with in_directory(path): - run_on_cmdline(logger, "svn info --non-interactive") + run_on_cmdline(logger, ["svn", "info", "--non-interactive"]) return True except (SubprocessCommandError, RuntimeError): return False @@ -171,7 +172,9 @@ def _does_revision_exist(self, revision: str) -> bool: def _list_of_tags(self) -> list[str]: """Get list of all available tags.""" - result = run_on_cmdline(logger, f"svn ls --non-interactive {self.remote}/tags") + result = run_on_cmdline( + logger, ["svn", "ls", "--non-interactive", f"{self.remote}/tags"] + ) return [ str(tag).strip("/\r") for tag in result.stdout.decode().split("\n") if tag ] @@ -180,7 +183,7 @@ def _list_of_tags(self) -> list[str]: def list_tool_info() -> None: """Print out version information.""" try: - result = run_on_cmdline(logger, "svn --version") + result = run_on_cmdline(logger, ["svn", "--version", "--non-interactive"]) except RuntimeError as exc: logger.debug( f"Something went wrong trying to get the version of svn: {exc}" @@ -304,7 +307,9 @@ def _export(url: str, rev: str = "", dst: str = ".") -> None: def _files_in_path(url_path: str) -> list[str]: return [ str(line) - for line in run_on_cmdline(logger, f"svn list --non-interactive {url_path}") + for line in run_on_cmdline( + logger, ["svn", "list", "--non-interactive", url_path] + ) .stdout.decode() .splitlines() ] @@ -322,13 +327,13 @@ def _license_files(url_path: str) -> list[str]: def _get_info_from_target(target: str = "") -> dict[str, str]: try: result = run_on_cmdline( - logger, f"svn info --non-interactive {target.strip()}" + logger, ["svn", "info", "--non-interactive", target.strip()] ).stdout.decode() except SubprocessCommandError as exc: - if exc.stdout.startswith("svn: E170013"): + if exc.stderr.startswith("svn: E170013"): raise RuntimeError( f">>>{exc.cmd}<<< failed!\n" - + f"'{target.strip()}' is not a valid URL or unreachable:\n{exc.stdout}" + + f"'{target.strip()}' is not a valid URL or unreachable:\n{exc.stderr or exc.stdout}" ) from exc raise @@ -347,7 +352,7 @@ def _get_last_changed_revision(target: str) -> str: if os.path.isdir(target): last_digits = re.compile(r"(?P\d+)(?!.*\d)") version = run_on_cmdline( - logger, f"svnversion {target.strip()}" + logger, ["svnversion", target.strip()] ).stdout.decode() parsed_version = last_digits.search(version) @@ -358,7 +363,14 @@ def _get_last_changed_revision(target: str) -> str: return str( run_on_cmdline( logger, - f"svn info --non-interactive --show-item last-changed-revision {target.strip()}", + [ + "svn", + "info", + "--non-interactive", + "--show-item", + "last-changed-revision", + target.strip(), + ], ) .stdout.decode() .strip() @@ -415,7 +427,7 @@ def _untracked_files(path: str, ignore: Sequence[str]) -> list[str]: result = ( run_on_cmdline( logger, - ["svn", "status", path], + ["svn", "status", "--non-interactive", path], ) .stdout.decode() .splitlines() @@ -441,7 +453,7 @@ def ignored_files(path: str) -> Sequence[str]: result = ( run_on_cmdline( logger, - ["svn", "status", "--no-ignore", "."], + ["svn", "status", "--non-interactive", "--no-ignore", "."], ) .stdout.decode() .splitlines() diff --git a/dfetch/util/cmdline.py b/dfetch/util/cmdline.py index c302e77d..d0de8245 100644 --- a/dfetch/util/cmdline.py +++ b/dfetch/util/cmdline.py @@ -3,7 +3,8 @@ import logging import os import subprocess # nosec -from typing import Any, Optional, Union # pylint: disable=unused-import +from collections.abc import Mapping +from typing import Any, Optional class SubprocessCommandError(Exception): @@ -24,8 +25,8 @@ def __init__( cmd_str: str = " ".join(cmd or []) self._message = f">>>{cmd_str}<<< returned {returncode}:{os.linesep}{stderr}" self.cmd = cmd_str - self.stderr = stdout - self.stdout = stderr + self.stdout = stdout + self.stderr = stderr self.returncode = returncode super().__init__(self._message) @@ -36,16 +37,15 @@ def message(self) -> str: def run_on_cmdline( - logger: logging.Logger, cmd: Union[str, list[str]] + logger: logging.Logger, + cmd: list[str], + env: Optional[Mapping[str, str]] = None, ) -> "subprocess.CompletedProcess[Any]": """Run a command and log the output, and raise if something goes wrong.""" logger.debug(f"Running {cmd}") - if not isinstance(cmd, list): - cmd = cmd.split(" ") - try: - proc = subprocess.run(cmd, capture_output=True, check=True) # nosec + proc = subprocess.run(cmd, env=env, capture_output=True, check=True) # nosec except subprocess.CalledProcessError as exc: raise SubprocessCommandError( exc.cmd, @@ -54,8 +54,7 @@ def run_on_cmdline( exc.returncode, ) from exc except FileNotFoundError as exc: - cmd = cmd[0] - raise RuntimeError(f"{cmd} not available on system, please install") from exc + raise RuntimeError(f"{cmd[0]} not available on system, please install") from exc stdout, stderr = proc.stdout, proc.stderr diff --git a/dfetch/vcs/git.py b/dfetch/vcs/git.py index 94e4bfb3..82051e57 100644 --- a/dfetch/vcs/git.py +++ b/dfetch/vcs/git.py @@ -1,5 +1,6 @@ """Git specific implementation.""" +import functools import os import re import shutil @@ -30,11 +31,57 @@ class Submodule(NamedTuple): def get_git_version() -> tuple[str, str]: """Get the name and version of git.""" - result = run_on_cmdline(logger, "git --version") + result = run_on_cmdline(logger, ["git", "--version"]) tool, version = result.stdout.decode().strip().split("version", maxsplit=1) return (str(tool), str(version)) +def _build_git_ssh_command() -> str: + """Returns a safe SSH command string for Git that enforces non-interactive mode. + + Respects existing GIT_SSH_COMMAND and git core.sshCommand. + """ + ssh_cmd = os.environ.get("GIT_SSH_COMMAND") + + if not ssh_cmd: + + try: + result = run_on_cmdline( + logger, + ["git", "config", "--get", "core.sshCommand"], + ) + ssh_cmd = result.stdout.decode().strip() + + except SubprocessCommandError: + ssh_cmd = None + + if not ssh_cmd: + ssh_cmd = "ssh" + + if "BatchMode=" not in ssh_cmd: + ssh_cmd += " -o BatchMode=yes" + else: + logger.debug(f'BatchMode already configured in "{ssh_cmd}"') + + return ssh_cmd + + +# As a cli tool, we can safely assume this remains stable during the runtime, caching for speed is better +@functools.lru_cache +def _extend_env_for_non_interactive_mode() -> dict[str, str]: + """Extend the environment vars for git running in non-interactive mode. + + See https://serverfault.com/a/1054253 for background info + """ + env = os.environ.copy() + env["GIT_TERMINAL_PROMPT"] = "0" + env["GIT_SSH_COMMAND"] = _build_git_ssh_command() + + # https://stackoverflow.com/questions/37182847/how-do-i-disable-git-credential-manager-for-windows#answer-45513654 + env["GCM_INTERACTIVE"] = "never" + return env + + class GitRemote: """A remote git repo.""" @@ -48,10 +95,14 @@ def is_git(self) -> bool: return True try: - run_on_cmdline(logger, f"git ls-remote --heads {self._remote}") + run_on_cmdline( + logger, + cmd=["git", "ls-remote", "--heads", self._remote], + env=_extend_env_for_non_interactive_mode(), + ) return True except SubprocessCommandError as exc: - if exc.returncode == 128 and "Could not resolve host" in exc.stdout: + if exc.returncode == 128 and "Could not resolve host" in exc.stderr: raise RuntimeError( f">>>{exc.cmd}<<< failed!\n" + f"'{self._remote}' is not a valid URL or unreachable:\n{exc.stderr or exc.stdout}" @@ -82,7 +133,9 @@ def get_default_branch(self) -> str: """Try to get the default branch or fallback to master.""" try: result = run_on_cmdline( - logger, f"git ls-remote --symref {self._remote} HEAD" + logger, + cmd=["git", "ls-remote", "--symref", self._remote, "HEAD"], + env=_extend_env_for_non_interactive_mode(), ).stdout.decode() except SubprocessCommandError: logger.debug( @@ -101,7 +154,9 @@ def get_default_branch(self) -> str: @staticmethod def _ls_remote(remote: str) -> dict[str, str]: result = run_on_cmdline( - logger, f"git ls-remote --heads --tags {remote}" + logger, + cmd=["git", "ls-remote", "--heads", "--tags", remote], + env=_extend_env_for_non_interactive_mode(), ).stdout.decode() info: dict[str, str] = {} @@ -156,12 +211,14 @@ def check_version_exists( temp_dir = tempfile.mkdtemp() exists = False with in_directory(temp_dir): - run_on_cmdline(logger, "git init") - run_on_cmdline(logger, f"git remote add origin {self._remote}") - run_on_cmdline(logger, "git checkout -b dfetch-local-branch") + run_on_cmdline(logger, ["git", "init"]) + run_on_cmdline(logger, ["git", "remote", "add", "origin", self._remote]) + run_on_cmdline(logger, ["git", "checkout", "-b", "dfetch-local-branch"]) try: run_on_cmdline( - logger, f"git fetch --dry-run --depth 1 origin {version}" + logger, + ["git", "fetch", "--dry-run", "--depth", "1", "origin", version], + env=_extend_env_for_non_interactive_mode(), ) exists = True except SubprocessCommandError as exc: @@ -185,7 +242,10 @@ def is_git(self) -> bool: """Check if is git.""" try: with in_directory(self._path): - run_on_cmdline(logger, "git status") + run_on_cmdline( + logger, + ["git", "status"], + ) return True except (SubprocessCommandError, RuntimeError): return False @@ -209,12 +269,12 @@ def checkout_version( # pylint: disable=too-many-arguments ignore (Optional[Sequence[str]]): Optional sequence of glob patterns to ignore (relative to src) """ with in_directory(self._path): - run_on_cmdline(logger, "git init") - run_on_cmdline(logger, f"git remote add origin {remote}") - run_on_cmdline(logger, "git checkout -b dfetch-local-branch") + run_on_cmdline(logger, ["git", "init"]) + run_on_cmdline(logger, ["git", "remote", "add", "origin", remote]) + run_on_cmdline(logger, ["git", "checkout", "-b", "dfetch-local-branch"]) if src or ignore: - run_on_cmdline(logger, "git config core.sparsecheckout true") + run_on_cmdline(logger, ["git", "config", "core.sparsecheckout", "true"]) with open( ".git/info/sparse-checkout", "a", encoding="utf-8" ) as sparse_checkout_file: @@ -228,11 +288,17 @@ def checkout_version( # pylint: disable=too-many-arguments sparse_checkout_file.write("\n") sparse_checkout_file.write("\n".join(ignore_abs_paths)) - run_on_cmdline(logger, f"git fetch --depth 1 origin {version}") - run_on_cmdline(logger, "git reset --hard FETCH_HEAD") + run_on_cmdline( + logger, + ["git", "fetch", "--depth", "1", "origin", version], + env=_extend_env_for_non_interactive_mode(), + ) + run_on_cmdline(logger, ["git", "reset", "--hard", "FETCH_HEAD"]) current_sha = ( - run_on_cmdline(logger, "git rev-parse HEAD").stdout.decode().strip() + run_on_cmdline(logger, ["git", "rev-parse", "HEAD"]) + .stdout.decode() + .strip() ) if src: @@ -305,10 +371,7 @@ def get_current_hash(self) -> str: def get_remote_url() -> str: """Get the url of the remote origin.""" try: - result = run_on_cmdline( - logger, - ["git", "remote", "get-url", "origin"], - ) + result = run_on_cmdline(logger, ["git", "remote", "get-url", "origin"]) decoded_result = str(result.stdout.decode()) except SubprocessCommandError: decoded_result = "" diff --git a/features/check-git-repo.feature b/features/check-git-repo.feature index c6bc68e3..58695ef7 100644 --- a/features/check-git-repo.feature +++ b/features/check-git-repo.feature @@ -205,3 +205,37 @@ Feature: Checking dependencies from a git repository SomeProjectNonExistentBranch: wanted (i-dont-exist), but not available at the upstream. SomeProjectNonExistentRevision: wanted (0123112321234123512361236123712381239123), but not available at the upstream. """ + + Scenario: Credentials required for remote + Given the manifest 'dfetch.yaml' + """ + manifest: + version: '0.0' + + projects: + - name: private-repo + url: https://github.com/dfetch-org/test-repo-private.git + """ + When I run "dfetch check" + Then the output starts with: + """ + Dfetch (0.10.0) + >>>git ls-remote --heads --tags https://github.com/dfetch-org/test-repo-private.git<<< returned 128: + """ + + Scenario: SSH issues + Given the manifest 'dfetch.yaml' + """ + manifest: + version: '0.0' + + projects: + - name: private-repo + url: git@github.com:dfetch-org/test-repo-private.git + """ + When I run "dfetch check" + Then the output starts with: + """ + Dfetch (0.10.0) + >>>git ls-remote --heads --tags git@github.com:dfetch-org/test-repo-private.git<<< returned 128: + """ diff --git a/features/steps/generic_steps.py b/features/steps/generic_steps.py index 40d940ab..14763a56 100644 --- a/features/steps/generic_steps.py +++ b/features/steps/generic_steps.py @@ -154,6 +154,48 @@ def list_dir(path): return result +def check_output(context, line_count=None): + """Check command output against expected text. + + Args: + context: Behave context with cmd_output and expected text + line_count: If set, compare only the first N lines of actual output + """ + expected_text = multisub( + patterns=[ + (git_hash, r"\1[commit hash]\2"), + (timestamp, "[timestamp]"), + (dfetch_title, ""), + (svn_error, "svn: EXXXXXX: "), + ], + text=context.text, + ) + + actual_text = multisub( + patterns=[ + (git_hash, r"\1[commit hash]\2"), + (timestamp, "[timestamp]"), + (ansi_escape, ""), + ( + re.compile(f"file:///{remote_server_path(context)}"), + "some-remote-server", + ), + (svn_error, "svn: EXXXXXX: "), + ], + text=context.cmd_output, + ) + + actual_lines = actual_text.splitlines()[:line_count] + diff = difflib.ndiff(actual_lines, expected_text.splitlines()) + + diffs = [x for x in diff if x[0] in ("+", "-")] + if diffs: + comp = "\n".join(diffs) + print(actual_text) + print(comp) + assert False, "Output not as expected!" + + @given('"{old}" is replaced with "{new}" in "{path}"') def step_impl(_, old: str, new: str, path: str): replace_in_file(path, old, new) @@ -247,40 +289,14 @@ def multisub(patterns: List[Tuple[Pattern[str], str]], text: str) -> str: return text -@then("the output shows") +@then("the output starts with:") def step_impl(context): - expected_text = multisub( - patterns=[ - (git_hash, r"\1[commit hash]\2"), - (timestamp, "[timestamp]"), - (dfetch_title, ""), - (svn_error, "svn: EXXXXXX: "), - ], - text=context.text, - ) + check_output(context, line_count=len(context.text.splitlines())) - actual_text = multisub( - patterns=[ - (git_hash, r"\1[commit hash]\2"), - (timestamp, "[timestamp]"), - (ansi_escape, ""), - ( - re.compile(f"file:///{remote_server_path(context)}"), - "some-remote-server", - ), - (svn_error, "svn: EXXXXXX: "), - ], - text=context.cmd_output, - ) - - diff = difflib.ndiff(actual_text.splitlines(), expected_text.splitlines()) - diffs = [x for x in diff if x[0] in ("+", "-")] - if diffs: - comp = "\n".join(diffs) - print(actual_text) - print(comp) - assert False, "Output not as expected!" +@then("the output shows") +def step_impl(context): + check_output(context) @then("the following projects are fetched") diff --git a/tests/test_git_vcs.py b/tests/test_git_vcs.py index 3bf65b11..0089dcd4 100644 --- a/tests/test_git_vcs.py +++ b/tests/test_git_vcs.py @@ -3,23 +3,33 @@ # mypy: ignore-errors # flake8: noqa -from unittest.mock import patch +import os +from subprocess import CompletedProcess +from unittest.mock import Mock, patch import pytest from dfetch.util.cmdline import SubprocessCommandError -from dfetch.vcs.git import GitLocalRepo, GitRemote +from dfetch.vcs.git import ( + GitLocalRepo, + GitRemote, + _build_git_ssh_command, +) @pytest.mark.parametrize( "name, cmd_result, expectation", [ - ("git repo", ["Yep!"], True), + ("git repo", [CompletedProcess(args=[], returncode=0, stdout="Yep!")], True), ("not a git repo", [SubprocessCommandError()], False), ("no git", [RuntimeError()], False), + ("somewhere.git", [], True), ], ) def test_remote_check(name, cmd_result, expectation): + + os.environ["GIT_SSH_COMMAND"] = "ssh" # prevents additional subprocess call + with patch("dfetch.vcs.git.run_on_cmdline") as run_on_cmdline_mock: run_on_cmdline_mock.side_effect = cmd_result @@ -104,3 +114,49 @@ def test_ls_remote(): } assert info == expected + + +@pytest.mark.parametrize( + "name, env_ssh, git_config_ssh, expected", + [ + ( + "env var present", + "ssh -i keyfile", + None, + "ssh -i keyfile -o BatchMode=yes", + ), + ( + "git config", + None, + "ssh -F configfile", + "ssh -F configfile -o BatchMode=yes", + ), + ("no env or git config", None, None, "ssh -o BatchMode=yes"), + ( + "env with batchmode", + "ssh -o BatchMode=yes", + None, + "ssh -o BatchMode=yes", + ), + ], +) +def test_build_git_ssh_command(name, env_ssh, git_config_ssh, expected): + + with patch.dict( + os.environ, {"GIT_SSH_COMMAND": env_ssh} if env_ssh else {}, clear=True + ): + mock_run_git_config = Mock() + if git_config_ssh is not None: + mock_run_git_config.return_value.stdout = git_config_ssh.encode() + else: + mock_run_git_config.side_effect = SubprocessCommandError() + + with patch("dfetch.vcs.git.run_on_cmdline", mock_run_git_config): + with patch("dfetch.vcs.git.logger") as mock_logger: + result = _build_git_ssh_command() + assert result == expected + + if "BatchMode=" in (env_ssh or git_config_ssh or ""): + mock_logger.debug.assert_called_once() + else: + mock_logger.debug.assert_not_called()