diff --git a/dfetch/project/superproject.py b/dfetch/project/superproject.py index 5b4afb84..bb73ae84 100644 --- a/dfetch/project/superproject.py +++ b/dfetch/project/superproject.py @@ -19,6 +19,7 @@ from dfetch.project.git import GitSubProject from dfetch.project.subproject import SubProject from dfetch.project.svn import SvnSubProject +from dfetch.util.util import resolve_absolute_path from dfetch.vcs.git import GitLocalRepo from dfetch.vcs.svn import SvnRepo @@ -40,10 +41,12 @@ def __init__(self) -> None: logger.debug(f"Using manifest {manifest_path}") self._manifest = parse(manifest_path) - self._root_directory = os.path.dirname(self._manifest.path) + self._root_directory = resolve_absolute_path( + os.path.dirname(self._manifest.path) + ) @property - def root_directory(self) -> str: + def root_directory(self) -> pathlib.Path: """Return the directory that contains the manifest file.""" return self._root_directory @@ -63,11 +66,12 @@ def get_sub_project(self, project: ProjectEntry) -> SubProject | None: def ignored_files(self, path: str) -> Sequence[str]: """Return a list of files that can be ignored in a given path.""" - if ( - os.path.commonprefix((pathlib.Path(path).resolve(), self.root_directory)) - != self.root_directory - ): - raise RuntimeError(f"{path} not in superproject {self.root_directory}!") + resolved_path = resolve_absolute_path(path) + + if not resolved_path.is_relative_to(self.root_directory): + raise RuntimeError( + f"{resolved_path} not in superproject {self.root_directory}!" + ) if GitLocalRepo(self.root_directory).is_git(): return GitLocalRepo.ignored_files(path) diff --git a/dfetch/util/util.py b/dfetch/util/util.py index 91194571..5ee2a450 100644 --- a/dfetch/util/util.py +++ b/dfetch/util/util.py @@ -63,14 +63,14 @@ def safe_rmtree(path: str) -> None: @contextmanager -def in_directory(path: str) -> Generator[str, None, None]: +def in_directory(path: Union[str, Path]) -> Generator[str, None, None]: """Work temporarily in a given directory.""" pwd = os.getcwd() if not os.path.isdir(path): path = os.path.dirname(path) os.chdir(path) try: - yield path + yield str(path) finally: os.chdir(pwd) @@ -159,3 +159,16 @@ def str_if_possible(data: list[str]) -> Union[str, list[str]]: if the list is empty, otherwise the original list. """ return "" if not data else data[0] if len(data) == 1 else data + + +def resolve_absolute_path(path: Union[str, Path]) -> Path: + """Return a guaranteed absolute Path, resolving symlinks. + + Args: + path: A string or Path to resolve. + + Notes: + - Uses os.path.realpath for reliable absolute paths across platforms. + - Handles Windows drive-relative paths and expands '~'. + """ + return Path(os.path.realpath(Path(path).expanduser())) diff --git a/dfetch/vcs/git.py b/dfetch/vcs/git.py index 82051e57..4e0af102 100644 --- a/dfetch/vcs/git.py +++ b/dfetch/vcs/git.py @@ -7,7 +7,7 @@ import tempfile from collections.abc import Generator, Sequence from pathlib import Path, PurePath -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Union from dfetch.log import get_logger from dfetch.util.cmdline import SubprocessCommandError, run_on_cmdline @@ -234,9 +234,9 @@ class GitLocalRepo: METADATA_DIR = ".git" - def __init__(self, path: str = ".") -> None: + def __init__(self, path: Union[str, Path] = ".") -> None: """Create a local git repo.""" - self._path = path + self._path = str(path) def is_git(self) -> bool: """Check if is git.""" diff --git a/dfetch/vcs/svn.py b/dfetch/vcs/svn.py index 1f3396c0..7f3cf1c1 100644 --- a/dfetch/vcs/svn.py +++ b/dfetch/vcs/svn.py @@ -4,7 +4,7 @@ import pathlib import re from collections.abc import Sequence -from typing import NamedTuple, Optional +from typing import NamedTuple, Optional, Union from dfetch.log import get_logger from dfetch.util.cmdline import SubprocessCommandError, run_on_cmdline @@ -76,10 +76,10 @@ class SvnRepo: def __init__( self, - path: str = ".", + path: Union[str, pathlib.Path] = ".", ) -> None: """Create a svn repo.""" - self._path = path + self._path = str(path) def is_svn(self) -> bool: """Check if is SVN.""" diff --git a/tests/test_check.py b/tests/test_check.py index bf9080ef..73dda675 100644 --- a/tests/test_check.py +++ b/tests/test_check.py @@ -4,6 +4,7 @@ # flake8: noqa import argparse +from pathlib import Path from unittest.mock import Mock, patch import pytest @@ -30,7 +31,7 @@ def test_check(name, projects): fake_superproject = Mock() fake_superproject.manifest = mock_manifest(projects) - fake_superproject.root_directory = "/tmp" + fake_superproject.root_directory = Path("/tmp") with patch("dfetch.commands.check.SuperProject", return_value=fake_superproject): with patch( diff --git a/tests/test_report.py b/tests/test_report.py index 9ac2cebc..86ea4d69 100644 --- a/tests/test_report.py +++ b/tests/test_report.py @@ -4,6 +4,7 @@ # flake8: noqa import argparse +from pathlib import Path from unittest.mock import Mock, patch import pytest @@ -30,7 +31,7 @@ def test_report(name, projects): fake_superproject = Mock() fake_superproject.manifest = mock_manifest(projects) - fake_superproject.root_directory = "/tmp" + fake_superproject.root_directory = Path("/tmp") with patch("dfetch.commands.report.SuperProject", return_value=fake_superproject): with patch("dfetch.log.DLogger.print_info_line") as mocked_print_info_line: diff --git a/tests/test_update.py b/tests/test_update.py index 24b9169e..c33014cd 100644 --- a/tests/test_update.py +++ b/tests/test_update.py @@ -4,6 +4,7 @@ # flake8: noqa import argparse +from pathlib import Path from unittest.mock import Mock, patch import pytest @@ -30,7 +31,7 @@ def test_update(name, projects): fake_superproject = Mock() fake_superproject.manifest = mock_manifest(projects) - fake_superproject.root_directory = "/tmp" + fake_superproject.root_directory = Path("/tmp") with patch("dfetch.commands.update.SuperProject", return_value=fake_superproject): with patch( @@ -53,7 +54,7 @@ def test_forced_update(): fake_superproject = Mock() fake_superproject.manifest = mock_manifest([{"name": "some_project"}]) - fake_superproject.root_directory = "/tmp" + fake_superproject.root_directory = Path("/tmp") fake_superproject.ignored_files.return_value = [] with patch("dfetch.commands.update.SuperProject", return_value=fake_superproject):