diff --git a/pygit2/_libgit2/ffi.pyi b/pygit2/_libgit2/ffi.pyi index b2f6a8f1..a911985f 100644 --- a/pygit2/_libgit2/ffi.pyi +++ b/pygit2/_libgit2/ffi.pyi @@ -42,6 +42,11 @@ class _Pointer(Generic[T]): @overload def __getitem__(self, item: slice[None, None, None]) -> bytes: ... +class ArrayC(Generic[T]): + # incomplete! + # def _len(self, ?) -> ?: ... + pass + class GitTimeC: # incomplete time: int @@ -132,13 +137,13 @@ class GitDescribeFormatOptionsC: version: int abbreviated_size: int always_use_long_format: int - dirty_suffix: char_pointer + dirty_suffix: ArrayC[char] class GitDescribeOptionsC: version: int max_candidates_tags: int describe_strategy: int - pattern: char_pointer + pattern: ArrayC[char] only_follow_first_parent: int show_commit_oid_as_fallback: int @@ -148,6 +153,11 @@ class GitDescribeResultC: class GitIndexC: pass +class GitIndexEntryC: + # incomplete? + mode: int + path: ArrayC[char] + class GitMergeFileResultC: pass @@ -158,11 +168,13 @@ class GitStashSaveOptionsC: version: int flags: int stasher: GitSignatureC - message: char_pointer + message: ArrayC[char] paths: GitStrrayC class GitStrrayC: - pass + # incomplete? + strings: NULL_TYPE | ArrayC[char] + count: int class GitTreeC: pass @@ -171,11 +183,11 @@ class GitRepositoryInitOptionsC: version: int flags: int mode: int - workdir_path: char_pointer - description: char_pointer - template_path: char_pointer - initial_head: char_pointer - origin_url: char_pointer + workdir_path: ArrayC[char] + description: ArrayC[char] + template_path: ArrayC[char] + initial_head: ArrayC[char] + origin_url: ArrayC[char] class GitCloneOptionsC: pass @@ -229,6 +241,10 @@ def new(a: Literal['git_attr_options *']) -> GitAttrOptionsC: ... @overload def new(a: Literal['git_buf *']) -> GitBufC: ... @overload +def new(a: Literal['char *'], b: bytes) -> char_pointer: ... +@overload +def new(a: Literal['char *[]'], b: list[char_pointer]) -> ArrayC[char_pointer]: ... +@overload def new(a: Literal['git_checkout_options *']) -> GitCheckoutOptionsC: ... @overload def new(a: Literal['git_commit **']) -> _Pointer[GitCommitC]: ... @@ -251,6 +267,8 @@ def new(a: Literal['struct git_reference **']) -> _Pointer[GitReferenceC]: ... @overload def new(a: Literal['git_index **']) -> _Pointer[GitIndexC]: ... @overload +def new(a: Literal['git_index_entry *']) -> GitIndexEntryC: ... +@overload def new(a: Literal['git_merge_file_result *']) -> GitMergeFileResultC: ... @overload def new(a: Literal['git_object *']) -> GitObjectC: ... @@ -263,13 +281,15 @@ def new(a: Literal['git_signature **']) -> _Pointer[GitSignatureC]: ... @overload def new(a: Literal['git_stash_save_options *']) -> GitStashSaveOptionsC: ... @overload +def new(a: Literal['git_strarray *']) -> GitStrrayC: ... +@overload def new(a: Literal['git_tree **']) -> _Pointer[GitTreeC]: ... @overload def new(a: Literal['git_buf *'], b: tuple[NULL_TYPE, Literal[0]]) -> GitBufC: ... @overload def new(a: Literal['char **']) -> _Pointer[char_pointer]: ... @overload -def new(a: Literal['char[]', 'char []'], b: bytes | NULL_TYPE) -> char_pointer: ... +def new(a: Literal['char[]', 'char []'], b: bytes | NULL_TYPE) -> ArrayC[char]: ... def addressof(a: object, attribute: str) -> _Pointer[object]: ... class buffer(bytes): diff --git a/pygit2/_pygit2.pyi b/pygit2/_pygit2.pyi index 6fa853f0..acee4984 100644 --- a/pygit2/_pygit2.pyi +++ b/pygit2/_pygit2.pyi @@ -19,6 +19,7 @@ from typing import ( from . import Index from ._libgit2.ffi import ( GitCommitC, + GitMergeOptionsC, GitObjectC, GitProxyOptionsC, GitRepositoryC, @@ -36,12 +37,16 @@ from .enums import ( BranchType, CheckoutStrategy, DeltaStatus, + DescribeStrategy, DiffFind, DiffFlag, DiffOption, DiffStatsFormat, FileMode, MergeAnalysis, + MergeFavor, + MergeFileFlag, + MergeFlag, MergePreference, ObjectType, Option, @@ -51,6 +56,7 @@ from .enums import ( ResetMode, SortMode, ) +from .filter import Filter from .remotes import Remote from .repository import BaseRepository from .submodules import SubmoduleCollection @@ -457,6 +463,7 @@ class Diff: patch: str | None patchid: Oid stats: DiffStats + text: str def find_similar( self, flags: DiffFind = DiffFind.FIND_BY_CONFIG, @@ -520,6 +527,7 @@ class DiffStats: class FilterSource: # probably incomplete + repo: object pass class GitError(Exception): ... @@ -529,9 +537,9 @@ class Mailmap: def __init__(self, *args) -> None: ... def add_entry( self, - real_name: str = ..., - real_email: str = ..., - replace_name: str = ..., + real_name: str | None = ..., + real_email: str | None = ..., + replace_name: str | None = ..., replace_email: str = ..., ) -> None: ... @staticmethod @@ -719,6 +727,7 @@ class Branches: class Repository: _pointer: GitRepositoryC + _repo: GitRepositoryC default_signature: Signature head: Reference head_is_detached: bool @@ -784,7 +793,7 @@ class Repository: def compress_references(self) -> None: ... @property def config(self) -> Config: ... - def create_blob(self, data: bytes) -> Oid: ... + def create_blob(self, data: str | bytes) -> Oid: ... def create_blob_fromdisk(self, path: str) -> Oid: ... def create_blob_fromiobase(self, iobase: IOBase) -> Oid: ... def create_blob_fromworkdir(self, path: str | Path) -> Oid: ... @@ -834,14 +843,26 @@ class Repository: ) -> Oid: ... def diff( self, - a: None | str | Reference = None, - b: None | str | Reference = None, + a: None | str | bytes | Oid | Reference = None, + b: None | str | bytes | Oid | Reference = None, cached: bool = False, flags: DiffOption = DiffOption.NORMAL, context_lines: int = 3, interhunk_lines: int = 0, ) -> Diff: ... def descendant_of(self, oid1: _OidArg, oid2: _OidArg) -> bool: ... + def describe( + self, + committish: str | Reference | Commit | None = None, + max_candidates_tags: int | None = None, + describe_strategy: DescribeStrategy = DescribeStrategy.DEFAULT, + pattern: str | None = None, + only_follow_first_parent: bool | None = None, + show_commit_oid_as_fallback: bool | None = None, + abbreviated_size: object | None = None, + always_use_long_format: bool | None = None, + dirty_suffix: str | None = None, + ) -> str: ... def expand_id(self, hex: str) -> Oid: ... def free(self) -> None: ... def get(self, key: _OidArg, default: Optional[Commit] = None) -> None | Object: ... @@ -867,12 +888,40 @@ class Repository: def lookup_reference(self, name: str) -> Reference: ... def lookup_reference_dwim(self, name: str) -> Reference: ... def lookup_worktree(self, name: str) -> Worktree: ... + def merge( + self, + source: Reference | Commit | Oid | str, + favor: MergeFavor = MergeFavor.NORMAL, + flags: MergeFlag = MergeFlag.FIND_RENAMES, + file_flags: MergeFileFlag = MergeFileFlag.DEFAULT, + ) -> None: ... def merge_analysis( self, their_head: _OidArg, our_ref: str = 'HEAD' ) -> tuple[MergeAnalysis, MergePreference]: ... def merge_base(self, oid1: _OidArg, oid2: _OidArg) -> Oid: ... def merge_base_many(self, oids: list[_OidArg]) -> Oid: ... def merge_base_octopus(self, oids: list[_OidArg]) -> Oid: ... + def merge_commits( + self, + ours: str | Oid | Commit, + theirs: str | Oid | Commit, + favor: MergeFavor = MergeFavor.NORMAL, + flags: MergeFlag = MergeFlag.FIND_RENAMES, + file_flags: MergeFileFlag = MergeFileFlag.DEFAULT, + ) -> Index: ... + @staticmethod + def _merge_options( + favor: int | MergeFavor, flags: int | MergeFlag, file_flags: int | MergeFileFlag + ) -> GitMergeOptionsC: ... + def merge_trees( + self, + ancestor: str | Oid | Tree, + ours: str | Oid | Tree, + theirs: str | Oid | Tree, + favor: MergeFavor = MergeFavor.NORMAL, + flags: MergeFlag = MergeFlag.FIND_RENAMES, + file_flags: MergeFileFlag = MergeFileFlag.DEFAULT, + ) -> Index: ... @property def message(self) -> str: ... def notes(self) -> Iterator[Note]: ... @@ -881,6 +930,9 @@ class Repository: self, flag: BranchType = BranchType.LOCAL ) -> list[bytes]: ... def raw_listall_references(self) -> list[bytes]: ... + @property + def raw_message(self) -> bytes: ... + def remove_message(self) -> None: ... def references_iterator_init(self) -> Iterator[Reference]: ... def references_iterator_next( self, @@ -1022,5 +1074,7 @@ def option(opt: Option, *args) -> None: ... def reference_is_valid_name(refname: str) -> bool: ... def tree_entry_cmp(a: Object, b: Object) -> int: ... def _cache_enums() -> None: ... +def filter_register(name: str, filter: type[Filter]) -> None: ... +def filter_unregister(name: str) -> None: ... _OidArg = str | Oid diff --git a/pygit2/filter.py b/pygit2/filter.py index 00c65184..abf3eb35 100644 --- a/pygit2/filter.py +++ b/pygit2/filter.py @@ -58,7 +58,7 @@ class Filter: def nattrs(cls) -> int: return len(cls.attributes.split()) - def check(self, src: FilterSource, attr_values: List[Optional[str]]): + def check(self, src: FilterSource, attr_values: List[Optional[str]]) -> None: """ Check whether this filter should be applied to the given source. @@ -77,7 +77,7 @@ def check(self, src: FilterSource, attr_values: List[Optional[str]]): def write( self, data: bytes, src: FilterSource, write_next: Callable[[bytes], None] - ): + ) -> None: """ Write input `data` to this filter. @@ -95,7 +95,7 @@ def write( """ write_next(data) - def close(self, write_next: Callable[[bytes], None]): + def close(self, write_next: Callable[[bytes], None]) -> None: """ Close this filter. diff --git a/pygit2/index.py b/pygit2/index.py index e6ae6593..1b52b797 100644 --- a/pygit2/index.py +++ b/pygit2/index.py @@ -26,9 +26,10 @@ import typing import warnings from dataclasses import dataclass +from os import PathLike # Import from pygit2 -from ._pygit2 import Diff, Oid, Tree +from ._pygit2 import Diff, Oid, Repository, Tree from .enums import DiffOption, FileMode from .errors import check_error from .ffi import C, ffi @@ -41,7 +42,7 @@ class Index: # a proper implementation in some places: e.g. checking the index type # from C code (see Tree_diff_to_index) - def __init__(self, path: str | None = None) -> None: + def __init__(self, path: str | PathLike[str] | None = None) -> None: """Create a new Index If path is supplied, the read and write methods will use that path @@ -116,16 +117,16 @@ def read(self, force=True): err = C.git_index_read(self._index, force) check_error(err, io=True) - def write(self): + def write(self) -> None: """Write the contents of the Index to disk.""" err = C.git_index_write(self._index) check_error(err, io=True) - def clear(self): + def clear(self) -> None: err = C.git_index_clear(self._index) check_error(err) - def read_tree(self, tree): + def read_tree(self, tree: Oid | Tree | str) -> None: """Replace the contents of the Index with those of the given tree, expressed either as a object or as an oid (string or ). @@ -134,6 +135,8 @@ def read_tree(self, tree): """ repo = self._repo if isinstance(tree, str): + if repo is None: + raise TypeError('id given but no associated repository') tree = repo[tree] if isinstance(tree, Oid): @@ -142,14 +145,14 @@ def read_tree(self, tree): tree = repo[tree] elif not isinstance(tree, Tree): - raise TypeError('argument must be Oid or Tree') + raise TypeError('argument must be Oid, Tree or str') tree_cptr = ffi.new('git_tree **') ffi.buffer(tree_cptr)[:] = tree._pointer[:] err = C.git_index_read_tree(self._index, tree_cptr[0]) check_error(err) - def write_tree(self, repo=None): + def write_tree(self, repo: Repository | None = None) -> Oid: """Create a tree out of the Index. Return the object of the written tree. @@ -172,23 +175,23 @@ def write_tree(self, repo=None): check_error(err) return Oid(raw=bytes(ffi.buffer(coid)[:])) - def remove(self, path, level=0): + def remove(self, path: PathLike[str] | str, level: int = 0) -> None: """Remove an entry from the Index.""" err = C.git_index_remove(self._index, to_bytes(path), level) check_error(err, io=True) - def remove_directory(self, path, level=0): + def remove_directory(self, path: PathLike[str] | str, level: int = 0) -> None: """Remove a directory from the Index.""" err = C.git_index_remove_directory(self._index, to_bytes(path), level) check_error(err, io=True) - def remove_all(self, pathspecs): + def remove_all(self, pathspecs: typing.Sequence[str | PathLike[str]]) -> None: """Remove all index entries matching pathspecs.""" with StrArray(pathspecs) as arr: err = C.git_index_remove_all(self._index, arr.ptr, ffi.NULL, ffi.NULL) check_error(err, io=True) - def add_all(self, pathspecs=None): + def add_all(self, pathspecs: None | list[str | PathLike[str]] = None) -> None: """Add or update index entries matching files in the working directory. If pathspecs are specified, only files matching those pathspecs will @@ -199,7 +202,7 @@ def add_all(self, pathspecs=None): err = C.git_index_add_all(self._index, arr.ptr, 0, ffi.NULL, ffi.NULL) check_error(err, io=True) - def add(self, path_or_entry): + def add(self, path_or_entry: 'IndexEntry | str | PathLike[str]') -> None: """Add or update an entry in the Index. If a path is given, that file will be added. The path must be relative @@ -217,11 +220,13 @@ def add(self, path_or_entry): path = path_or_entry err = C.git_index_add_bypath(self._index, to_bytes(path)) else: - raise TypeError('argument must be string or IndexEntry') + raise TypeError('argument must be string, Path or IndexEntry') check_error(err, io=True) - def add_conflict(self, ancestor, ours, theirs): + def add_conflict( + self, ancestor: 'IndexEntry', ours: 'IndexEntry', theirs: 'IndexEntry | None' + ) -> None: """ Add or update index entries to represent a conflict. Any staged entries that exist at the given paths will be removed. @@ -243,7 +248,9 @@ def add_conflict(self, ancestor, ours, theirs): if theirs and not isinstance(theirs, IndexEntry): raise TypeError('theirs has to be an instance of IndexEntry or None') - centry_ancestor = centry_ours = centry_theirs = ffi.NULL + centry_ancestor: ffi.NULL_TYPE | ffi.GitIndexEntryC = ffi.NULL + centry_ours: ffi.NULL_TYPE | ffi.GitIndexEntryC = ffi.NULL + centry_theirs: ffi.NULL_TYPE | ffi.GitIndexEntryC = ffi.NULL if ancestor is not None: centry_ancestor, _ = ancestor._to_c() if ours is not None: @@ -418,7 +425,7 @@ def _from_c(cls, centry): class IndexEntry: - path: str + path: str | PathLike[str] 'The path of this entry' id: Oid @@ -427,7 +434,9 @@ class IndexEntry: mode: FileMode 'The mode of this entry, a FileMode value' - def __init__(self, path, object_id: Oid, mode: FileMode): + def __init__( + self, path: str | PathLike[str], object_id: Oid, mode: FileMode + ) -> None: self.path = path self.id = object_id self.mode = mode @@ -459,7 +468,7 @@ def __eq__(self, other): self.path == other.path and self.id == other.id and self.mode == other.mode ) - def _to_c(self): + def _to_c(self) -> tuple['ffi.GitIndexEntryC', 'ffi.ArrayC[ffi.char]']: """Convert this entry into the C structure The first returned arg is the pointer, the second is the reference to diff --git a/pygit2/repository.py b/pygit2/repository.py index 7e0b6362..df54065f 100644 --- a/pygit2/repository.py +++ b/pygit2/repository.py @@ -670,7 +670,9 @@ def index(self): # Merging # @staticmethod - def _merge_options(favor: MergeFavor, flags: MergeFlag, file_flags: MergeFileFlag): + def _merge_options( + favor: int | MergeFavor, flags: int | MergeFlag, file_flags: int | MergeFileFlag + ): """Return a 'git_merge_opts *'""" # Check arguments type @@ -1183,7 +1185,7 @@ def stash( if paths: arr = StrArray(paths) - opts.paths = arr.ptr[0] + opts.paths = arr.ptr[0] # type: ignore[index] coid = ffi.new('git_oid *') err = C.git_stash_save_with_opts(coid, self._repo, opts) diff --git a/pygit2/utils.py b/pygit2/utils.py index e2d4b4c4..8b01d3a1 100644 --- a/pygit2/utils.py +++ b/pygit2/utils.py @@ -25,11 +25,25 @@ import contextlib import os -from typing import Generic, Iterator, Protocol, TypeVar, Union, overload +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Generic, + Iterator, + Optional, + Protocol, + Sequence, + TypeVar, + Union, + overload, +) # Import from pygit2 from .ffi import C, ffi +if TYPE_CHECKING: + from ._libgit2.ffi import ArrayC, GitStrrayC, char + def maybe_string(ptr): if not ptr: @@ -130,7 +144,11 @@ class StrArray: contents of 'struct' only remain valid within the StrArray context. """ - def __init__(self, lst): + __array: 'GitStrrayC | ffi.NULL_TYPE' + __strings: list['None | ArrayC[char]'] + __arr: 'ArrayC[char]' + + def __init__(self, lst: None | Sequence[str | os.PathLike[str]]): # Allow passing in None as lg2 typically considers them the same as empty if lst is None: self.__array = ffi.NULL @@ -139,7 +157,7 @@ def __init__(self, lst): if not isinstance(lst, (list, tuple)): raise TypeError('Value must be a list') - strings = [None] * len(lst) + strings: list[None | 'ArrayC[char]'] = [None] * len(lst) for i in range(len(lst)): li = lst[i] if not isinstance(li, str) and not hasattr(li, '__fspath__'): @@ -147,21 +165,26 @@ def __init__(self, lst): strings[i] = ffi.new('char []', to_bytes(li)) - self.__arr = ffi.new('char *[]', strings) + self.__arr = ffi.new('char *[]', strings) # type: ignore[call-overload] self.__strings = strings - self.__array = ffi.new('git_strarray *', [self.__arr, len(strings)]) + self.__array = ffi.new('git_strarray *', [self.__arr, len(strings)]) # type: ignore[call-overload] - def __enter__(self): + def __enter__(self) -> 'StrArray': return self - def __exit__(self, type, value, traceback): + def __exit__( + self, + exc_type: Optional[type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> None: pass @property - def ptr(self): + def ptr(self) -> 'GitStrrayC | ffi.NULL_TYPE': return self.__array - def assign_to(self, git_strarray): + def assign_to(self, git_strarray: 'GitStrrayC') -> None: if self.__array == ffi.NULL: git_strarray.strings = ffi.NULL git_strarray.count = 0 diff --git a/test/test_describe.py b/test/test_describe.py index f24c60a9..963649b4 100644 --- a/test/test_describe.py +++ b/test/test_describe.py @@ -28,10 +28,11 @@ import pytest import pygit2 +from pygit2 import Oid, Repository from pygit2.enums import DescribeStrategy, ObjectType -def add_tag(repo, name, target): +def add_tag(repo: Repository, name: str, target: str) -> Oid: message = 'Example tag.\n' tagger = pygit2.Signature('John Doe', 'jdoe@example.com', 12347, 0) @@ -39,21 +40,21 @@ def add_tag(repo, name, target): return sha -def test_describe(testrepo): +def test_describe(testrepo: Repository) -> None: add_tag(testrepo, 'thetag', '4ec4389a8068641da2d6578db0419484972284c8') assert 'thetag-2-g2be5719' == testrepo.describe() -def test_describe_without_ref(testrepo): +def test_describe_without_ref(testrepo: Repository) -> None: with pytest.raises(pygit2.GitError): testrepo.describe() -def test_describe_default_oid(testrepo): +def test_describe_default_oid(testrepo: Repository) -> None: assert '2be5719' == testrepo.describe(show_commit_oid_as_fallback=True) -def test_describe_strategies(testrepo): +def test_describe_strategies(testrepo: Repository) -> None: assert 'heads/master' == testrepo.describe(describe_strategy=DescribeStrategy.ALL) testrepo.create_reference( @@ -66,14 +67,14 @@ def test_describe_strategies(testrepo): ) -def test_describe_pattern(testrepo): +def test_describe_pattern(testrepo: Repository) -> None: add_tag(testrepo, 'private/tag1', '5ebeeebb320790caf276b9fc8b24546d63316533') add_tag(testrepo, 'public/tag2', '4ec4389a8068641da2d6578db0419484972284c8') assert 'public/tag2-2-g2be5719' == testrepo.describe(pattern='public/*') -def test_describe_committish(testrepo): +def test_describe_committish(testrepo: Repository) -> None: add_tag(testrepo, 'thetag', 'acecd5ea2924a4b900e7e149496e1f4b57976e51') assert 'thetag-4-g2be5719' == testrepo.describe(committish='HEAD') assert 'thetag-1-g5ebeeeb' == testrepo.describe(committish='HEAD^') @@ -86,28 +87,28 @@ def test_describe_committish(testrepo): assert 'thetag-1-g6aaa262' == testrepo.describe(committish='6aaa262') -def test_describe_follows_first_branch_only(testrepo): +def test_describe_follows_first_branch_only(testrepo: Repository) -> None: add_tag(testrepo, 'thetag', '4ec4389a8068641da2d6578db0419484972284c8') with pytest.raises(KeyError): testrepo.describe(only_follow_first_parent=True) -def test_describe_abbreviated_size(testrepo): +def test_describe_abbreviated_size(testrepo: Repository) -> None: add_tag(testrepo, 'thetag', '4ec4389a8068641da2d6578db0419484972284c8') assert 'thetag-2-g2be5719152d4f82c' == testrepo.describe(abbreviated_size=16) assert 'thetag' == testrepo.describe(abbreviated_size=0) -def test_describe_long_format(testrepo): +def test_describe_long_format(testrepo: Repository) -> None: add_tag(testrepo, 'thetag', '2be5719152d4f82c7302b1c0932d8e5f0a4a0e98') assert 'thetag-0-g2be5719' == testrepo.describe(always_use_long_format=True) -def test_describe_dirty(dirtyrepo): +def test_describe_dirty(dirtyrepo: Repository) -> None: add_tag(dirtyrepo, 'thetag', 'a763aa560953e7cfb87ccbc2f536d665aa4dff22') assert 'thetag' == dirtyrepo.describe() -def test_describe_dirty_with_suffix(dirtyrepo): +def test_describe_dirty_with_suffix(dirtyrepo: Repository) -> None: add_tag(dirtyrepo, 'thetag', 'a763aa560953e7cfb87ccbc2f536d665aa4dff22') assert 'thetag-dirty' == dirtyrepo.describe(dirty_suffix='-dirty') diff --git a/test/test_diff.py b/test/test_diff.py index 5deb3481..dea3a92e 100644 --- a/test/test_diff.py +++ b/test/test_diff.py @@ -27,10 +27,12 @@ import textwrap from itertools import chain +from typing import Iterator import pytest import pygit2 +from pygit2 import Diff, Repository from pygit2.enums import DeltaStatus, DiffFlag, DiffOption, DiffStatsFormat, FileMode COMMIT_SHA1_1 = '5fe808e8953c12735680c257f56600cb0de44b10' @@ -169,7 +171,7 @@ """ -def test_diff_empty_index(dirtyrepo): +def test_diff_empty_index(dirtyrepo: Repository) -> None: repo = dirtyrepo head = repo[repo.lookup_reference('HEAD').resolve().target] @@ -182,7 +184,7 @@ def test_diff_empty_index(dirtyrepo): assert DIFF_HEAD_TO_INDEX_EXPECTED == files -def test_workdir_to_tree(dirtyrepo): +def test_workdir_to_tree(dirtyrepo: Repository) -> None: repo = dirtyrepo head = repo[repo.lookup_reference('HEAD').resolve().target] @@ -195,22 +197,22 @@ def test_workdir_to_tree(dirtyrepo): assert DIFF_HEAD_TO_WORKDIR_EXPECTED == files -def test_index_to_workdir(dirtyrepo): +def test_index_to_workdir(dirtyrepo: Repository) -> None: diff = dirtyrepo.diff() files = [patch.delta.new_file.path for patch in diff] assert DIFF_INDEX_TO_WORK_EXPECTED == files -def test_diff_invalid(barerepo): +def test_diff_invalid(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] with pytest.raises(TypeError): - commit_a.tree.diff_to_tree(commit_b) + commit_a.tree.diff_to_tree(commit_b) # type: ignore with pytest.raises(TypeError): - commit_a.tree.diff_to_index(commit_b) + commit_a.tree.diff_to_index(commit_b) # type: ignore -def test_diff_empty_index_bare(barerepo): +def test_diff_empty_index_bare(barerepo: Repository) -> None: repo = barerepo head = repo[repo.lookup_reference('HEAD').resolve().target] @@ -227,11 +229,11 @@ def test_diff_empty_index_bare(barerepo): assert [x.name for x in head.tree] == files -def test_diff_tree(barerepo): +def test_diff_tree(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] - def _test(diff): + def _test(diff: Diff) -> None: assert diff is not None assert 2 == sum(map(lambda x: len(x.hunks), diff)) @@ -260,11 +262,11 @@ def _test(diff): _test(barerepo.diff(COMMIT_SHA1_1, COMMIT_SHA1_2)) -def test_diff_empty_tree(barerepo): +def test_diff_empty_tree(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] diff = commit_a.tree.diff_to_tree() - def get_context_for_lines(diff): + def get_context_for_lines(diff: Diff) -> Iterator[str]: hunks = chain.from_iterable(map(lambda x: x.hunks, diff)) lines = chain.from_iterable(map(lambda x: x.lines, hunks)) return map(lambda x: x.origin, lines) @@ -279,12 +281,12 @@ def get_context_for_lines(diff): assert all('+' == x for x in get_context_for_lines(diff_swaped)) -def test_diff_revparse(barerepo): +def test_diff_revparse(barerepo: Repository) -> None: diff = barerepo.diff('HEAD', 'HEAD~6') assert type(diff) is pygit2.Diff -def test_diff_tree_opts(barerepo): +def test_diff_tree_opts(barerepo: Repository) -> None: commit_c = barerepo[COMMIT_SHA1_3] commit_d = barerepo[COMMIT_SHA1_4] @@ -298,7 +300,7 @@ def test_diff_tree_opts(barerepo): assert 1 == len(diff[0].hunks) -def test_diff_merge(barerepo): +def test_diff_merge(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] commit_c = barerepo[COMMIT_SHA1_3] @@ -325,7 +327,7 @@ def test_diff_merge(barerepo): assert patch.delta.new_file.path == 'a' -def test_diff_patch(barerepo): +def test_diff_patch(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] @@ -334,7 +336,7 @@ def test_diff_patch(barerepo): assert len(diff) == len([patch for patch in diff]) -def test_diff_ids(barerepo): +def test_diff_ids(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] patch = commit_a.tree.diff_to_tree(commit_b.tree)[0] @@ -343,7 +345,7 @@ def test_diff_ids(barerepo): assert delta.new_file.id == 'af431f20fc541ed6d5afede3e2dc7160f6f01f16' -def test_diff_patchid(barerepo): +def test_diff_patchid(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] diff = commit_a.tree.diff_to_tree(commit_b.tree) @@ -351,7 +353,7 @@ def test_diff_patchid(barerepo): assert diff.patchid == PATCHID -def test_hunk_content(barerepo): +def test_hunk_content(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] patch = commit_a.tree.diff_to_tree(commit_b.tree)[0] @@ -362,7 +364,7 @@ def test_hunk_content(barerepo): assert line.content == line.raw_content.decode() -def test_find_similar(barerepo): +def test_find_similar(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_6] commit_b = barerepo[COMMIT_SHA1_7] @@ -376,7 +378,7 @@ def test_find_similar(barerepo): assert any(x.delta.status_char() == 'R' for x in diff) -def test_diff_stats(barerepo): +def test_diff_stats(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] @@ -391,7 +393,7 @@ def test_diff_stats(barerepo): assert STATS_EXPECTED == formatted -def test_deltas(barerepo): +def test_deltas(barerepo: Repository) -> None: commit_a = barerepo[COMMIT_SHA1_1] commit_b = barerepo[COMMIT_SHA1_2] diff = commit_a.tree.diff_to_tree(commit_b.tree) @@ -414,7 +416,7 @@ def test_deltas(barerepo): # assert delta.flags == patch_delta.flags -def test_diff_parse(barerepo): +def test_diff_parse(barerepo: Repository) -> None: diff = pygit2.Diff.parse_diff(PATCH) stats = diff.stats @@ -426,9 +428,9 @@ def test_diff_parse(barerepo): assert 2 == len(deltas) -def test_parse_diff_null(): +def test_parse_diff_null() -> None: with pytest.raises(TypeError): - pygit2.Diff.parse_diff(None) + pygit2.Diff.parse_diff(None) # type: ignore def test_parse_diff_bad(): @@ -445,7 +447,7 @@ def test_parse_diff_bad(): pygit2.Diff.parse_diff(diff) -def test_diff_blobs(emptyrepo): +def test_diff_blobs(emptyrepo: Repository) -> None: repo = emptyrepo blob1 = repo.create_blob(TEXT_BLOB1.encode()) blob2 = repo.create_blob(TEXT_BLOB2.encode()) diff --git a/test/test_diff_binary.py b/test/test_diff_binary.py index e23583ad..9eb3a38c 100644 --- a/test/test_diff_binary.py +++ b/test/test_diff_binary.py @@ -23,16 +23,20 @@ # the Free Software Foundation, 51 Franklin Street, Fifth Floor, # Boston, MA 02110-1301, USA. +from pathlib import Path +from typing import Generator + import pytest import pygit2 +from pygit2 import Repository from pygit2.enums import DiffOption from . import utils @pytest.fixture -def repo(tmp_path): +def repo(tmp_path: Path) -> Generator[Repository, None, None]: with utils.TemporaryRepository('binaryfilerepo.zip', tmp_path) as path: yield pygit2.Repository(path) @@ -54,7 +58,7 @@ def repo(tmp_path): """ -def test_binary_diff(repo): +def test_binary_diff(repo: Repository) -> None: diff = repo.diff('HEAD', 'HEAD^') assert PATCH_BINARY == diff.patch diff = repo.diff('HEAD', 'HEAD^', flags=DiffOption.SHOW_BINARY) diff --git a/test/test_filter.py b/test/test_filter.py index 5fedb22b..d29b8b7e 100644 --- a/test/test_filter.py +++ b/test/test_filter.py @@ -1,40 +1,52 @@ import codecs from io import BytesIO +from typing import Callable, Generator import pytest import pygit2 +from pygit2 import Blob, Filter, FilterSource, Repository from pygit2.enums import BlobFilter from pygit2.errors import Passthrough -def _rot13(data): +def _rot13(data: bytes) -> bytes: return codecs.encode(data.decode('utf-8'), 'rot_13').encode('utf-8') class _Rot13Filter(pygit2.Filter): attributes = 'text' - def write(self, data, src, write_next): + def write( + self, + data: bytes, + src: FilterSource, + write_next: Callable[[bytes], None], + ) -> None: return super().write(_rot13(data), src, write_next) class _BufferedFilter(pygit2.Filter): attributes = 'text' - def __init__(self): + def __init__(self) -> None: super().__init__() self.buf = BytesIO() - def write(self, data, src, write_next): + def write( + self, + data: bytes, + src: FilterSource, + write_next: Callable[[bytes], None], + ) -> None: self.buf.write(data) - def close(self, write_next): + def close(self, write_next: Callable[[bytes], None]) -> None: write_next(_rot13(self.buf.getvalue())) class _PassthroughFilter(_Rot13Filter): - def check(self, src, attr_values): + def check(self, src: FilterSource, attr_values: list[str | None]) -> None: assert attr_values == [None] assert src.repo raise Passthrough @@ -45,36 +57,37 @@ class _UnmatchedFilter(_Rot13Filter): @pytest.fixture -def rot13_filter(): +def rot13_filter() -> Generator[None, None, None]: pygit2.filter_register('rot13', _Rot13Filter) yield pygit2.filter_unregister('rot13') @pytest.fixture -def passthrough_filter(): +def passthrough_filter() -> Generator[None, None, None]: pygit2.filter_register('passthrough-rot13', _PassthroughFilter) yield pygit2.filter_unregister('passthrough-rot13') @pytest.fixture -def buffered_filter(): +def buffered_filter() -> Generator[None, None, None]: pygit2.filter_register('buffered-rot13', _BufferedFilter) yield pygit2.filter_unregister('buffered-rot13') @pytest.fixture -def unmatched_filter(): +def unmatched_filter() -> Generator[None, None, None]: pygit2.filter_register('unmatched-rot13', _UnmatchedFilter) yield pygit2.filter_unregister('unmatched-rot13') -def test_filter(testrepo, rot13_filter): +def test_filter(testrepo: Repository, rot13_filter: Filter) -> None: blob_oid = testrepo.create_blob_fromworkdir('bye.txt') blob = testrepo[blob_oid] + assert isinstance(blob, Blob) flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD assert b'olr jbeyq\n' == blob.data with pygit2.BlobIO(blob) as reader: @@ -83,9 +96,10 @@ def test_filter(testrepo, rot13_filter): assert b'bye world\n' == reader.read() -def test_filter_buffered(testrepo, buffered_filter): +def test_filter_buffered(testrepo: Repository, buffered_filter: Filter) -> None: blob_oid = testrepo.create_blob_fromworkdir('bye.txt') blob = testrepo[blob_oid] + assert isinstance(blob, Blob) flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD assert b'olr jbeyq\n' == blob.data with pygit2.BlobIO(blob) as reader: @@ -94,9 +108,10 @@ def test_filter_buffered(testrepo, buffered_filter): assert b'bye world\n' == reader.read() -def test_filter_passthrough(testrepo, passthrough_filter): +def test_filter_passthrough(testrepo: Repository, passthrough_filter: Filter) -> None: blob_oid = testrepo.create_blob_fromworkdir('bye.txt') blob = testrepo[blob_oid] + assert isinstance(blob, Blob) flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD assert b'bye world\n' == blob.data with pygit2.BlobIO(blob) as reader: @@ -105,9 +120,10 @@ def test_filter_passthrough(testrepo, passthrough_filter): assert b'bye world\n' == reader.read() -def test_filter_unmatched(testrepo, unmatched_filter): +def test_filter_unmatched(testrepo: Repository, unmatched_filter: Filter) -> None: blob_oid = testrepo.create_blob_fromworkdir('bye.txt') blob = testrepo[blob_oid] + assert isinstance(blob, Blob) flags = BlobFilter.CHECK_FOR_BINARY | BlobFilter.ATTRIBUTES_FROM_HEAD assert b'bye world\n' == blob.data with pygit2.BlobIO(blob) as reader: @@ -116,7 +132,7 @@ def test_filter_unmatched(testrepo, unmatched_filter): assert b'bye world\n' == reader.read() -def test_filter_cleanup(dirtyrepo, rot13_filter): +def test_filter_cleanup(dirtyrepo: Repository, rot13_filter: Filter) -> None: # Indirectly test that pygit2_filter_cleanup has the GIL # before calling pygit2_filter_payload_free. dirtyrepo.diff() diff --git a/test/test_index.py b/test/test_index.py index b1fccba9..086863b7 100644 --- a/test/test_index.py +++ b/test/test_index.py @@ -30,21 +30,21 @@ import pytest import pygit2 -from pygit2 import Index, IndexEntry, Oid, Repository +from pygit2 import Index, IndexEntry, Oid, Repository, Tree from pygit2.enums import FileMode from . import utils -def test_bare(barerepo): +def test_bare(barerepo: Repository) -> None: assert len(barerepo.index) == 0 -def test_index(testrepo): +def test_index(testrepo: Repository) -> None: assert testrepo.index is not None -def test_read(testrepo): +def test_read(testrepo: Repository) -> None: index = testrepo.index assert len(index) == 2 @@ -60,7 +60,7 @@ def test_read(testrepo): assert index[1].id == sha -def test_add(testrepo): +def test_add(testrepo: Repository) -> None: index = testrepo.index sha = '0907563af06c7464d62a70cdd135a6ba7d2b41d8' @@ -71,7 +71,7 @@ def test_add(testrepo): assert index['bye.txt'].id == sha -def test_add_aspath(testrepo): +def test_add_aspath(testrepo: Repository) -> None: index = testrepo.index assert 'bye.txt' not in index @@ -79,7 +79,7 @@ def test_add_aspath(testrepo): assert 'bye.txt' in index -def test_add_all(testrepo): +def test_add_all(testrepo: Repository) -> None: clear(testrepo) sha_bye = '0907563af06c7464d62a70cdd135a6ba7d2b41d8' @@ -113,7 +113,7 @@ def test_add_all(testrepo): assert index['hello.txt'].id == sha_hello -def test_add_all_aspath(testrepo): +def test_add_all_aspath(testrepo: Repository) -> None: clear(testrepo) index = testrepo.index @@ -122,14 +122,14 @@ def test_add_all_aspath(testrepo): assert 'hello.txt' in index -def clear(repo): +def clear(repo: Repository) -> None: index = repo.index assert len(index) == 2 index.clear() assert len(index) == 0 -def test_write(testrepo): +def test_write(testrepo: Repository) -> None: index = testrepo.index index.add('bye.txt') index.write() @@ -140,7 +140,7 @@ def test_write(testrepo): assert 'bye.txt' in index -def test_read_tree(testrepo): +def test_read_tree(testrepo: Repository) -> None: tree_oid = '68aba62e560c0ebc3396e8ae9335232cd93a3f60' # Test reading first tree index = testrepo.index @@ -154,11 +154,11 @@ def test_read_tree(testrepo): assert len(index) == 2 -def test_write_tree(testrepo): +def test_write_tree(testrepo: Repository) -> None: assert testrepo.index.write_tree() == 'fd937514cb799514d4b81bb24c5fcfeb6472b245' -def test_iter(testrepo): +def test_iter(testrepo: Repository) -> None: index = testrepo.index n = len(index) assert len(list(index)) == n @@ -168,7 +168,7 @@ def test_iter(testrepo): assert list(x.id for x in index) == entries -def test_mode(testrepo): +def test_mode(testrepo: Repository) -> None: """ Testing that we can access an index entry mode. """ @@ -178,7 +178,7 @@ def test_mode(testrepo): assert hello_mode == 33188 -def test_bare_index(testrepo): +def test_bare_index(testrepo: Repository) -> None: index = pygit2.Index(Path(testrepo.path) / 'index') assert [x.id for x in index] == [x.id for x in testrepo.index] @@ -186,21 +186,21 @@ def test_bare_index(testrepo): index.add('bye.txt') -def test_remove(testrepo): +def test_remove(testrepo: Repository) -> None: index = testrepo.index assert 'hello.txt' in index index.remove('hello.txt') assert 'hello.txt' not in index -def test_remove_directory(dirtyrepo): +def test_remove_directory(dirtyrepo: Repository) -> None: index = dirtyrepo.index assert 'subdir/current_file' in index index.remove_directory('subdir') assert 'subdir/current_file' not in index -def test_remove_all(testrepo): +def test_remove_all(testrepo: Repository) -> None: index = testrepo.index assert 'hello.txt' in index index.remove_all(['*.txt']) @@ -209,28 +209,28 @@ def test_remove_all(testrepo): index.remove_all(['not-existing']) # this doesn't error -def test_remove_aspath(testrepo): +def test_remove_aspath(testrepo: Repository) -> None: index = testrepo.index assert 'hello.txt' in index index.remove(Path('hello.txt')) assert 'hello.txt' not in index -def test_remove_directory_aspath(dirtyrepo): +def test_remove_directory_aspath(dirtyrepo: Repository) -> None: index = dirtyrepo.index assert 'subdir/current_file' in index index.remove_directory(Path('subdir')) assert 'subdir/current_file' not in index -def test_remove_all_aspath(testrepo): +def test_remove_all_aspath(testrepo: Repository) -> None: index = testrepo.index assert 'hello.txt' in index index.remove_all([Path('hello.txt')]) assert 'hello.txt' not in index -def test_change_attributes(testrepo): +def test_change_attributes(testrepo: Repository) -> None: index = testrepo.index entry = index['hello.txt'] ign_entry = index['.gitignore'] @@ -244,7 +244,7 @@ def test_change_attributes(testrepo): assert FileMode.BLOB_EXECUTABLE == entry.mode -def test_write_tree_to(testrepo, tmp_path): +def test_write_tree_to(testrepo: Repository, tmp_path: Path) -> None: pygit2.option(pygit2.enums.Option.ENABLE_STRICT_OBJECT_CREATION, False) with utils.TemporaryRepository('emptyrepo.zip', tmp_path) as path: nrepo = Repository(path) @@ -252,7 +252,7 @@ def test_write_tree_to(testrepo, tmp_path): assert nrepo[id] is not None -def test_create_entry(testrepo): +def test_create_entry(testrepo: Repository) -> None: index = testrepo.index hello_entry = index['hello.txt'] entry = pygit2.IndexEntry('README.md', hello_entry.id, hello_entry.mode) @@ -260,7 +260,7 @@ def test_create_entry(testrepo): assert '60e769e57ae1d6a2ab75d8d253139e6260e1f912' == index.write_tree() -def test_create_entry_aspath(testrepo): +def test_create_entry_aspath(testrepo: Repository) -> None: index = testrepo.index hello_entry = index[Path('hello.txt')] entry = pygit2.IndexEntry(Path('README.md'), hello_entry.id, hello_entry.mode) @@ -268,7 +268,7 @@ def test_create_entry_aspath(testrepo): index.write_tree() -def test_entry_eq(testrepo): +def test_entry_eq(testrepo: Repository) -> None: index = testrepo.index hello_entry = index['hello.txt'] entry = pygit2.IndexEntry(hello_entry.path, hello_entry.id, hello_entry.mode) @@ -285,7 +285,7 @@ def test_entry_eq(testrepo): assert hello_entry != entry -def test_entry_repr(testrepo): +def test_entry_repr(testrepo: Repository) -> None: index = testrepo.index hello_entry = index['hello.txt'] assert ( @@ -306,16 +306,18 @@ def test_create_empty_read_tree_as_string(): index = Index() # no repo associated, so we don't know where to read from with pytest.raises(TypeError): - index('read_tree', 'fd937514cb799514d4b81bb24c5fcfeb6472b245') + index('read_tree', 'fd937514cb799514d4b81bb24c5fcfeb6472b245') # type: ignore -def test_create_empty_read_tree(testrepo): +def test_create_empty_read_tree(testrepo: Repository) -> None: index = Index() - index.read_tree(testrepo['fd937514cb799514d4b81bb24c5fcfeb6472b245']) + tree = testrepo['fd937514cb799514d4b81bb24c5fcfeb6472b245'] + assert isinstance(tree, Tree) + index.read_tree(tree) @utils.fails_in_macos -def test_add_conflict(testrepo): +def test_add_conflict(testrepo: Repository) -> None: ancestor_blob_id = testrepo.create_blob('ancestor') ancestor = IndexEntry('conflict.txt', ancestor_blob_id, FileMode.BLOB_EXECUTABLE) diff --git a/test/test_mailmap.py b/test/test_mailmap.py index e5a3b90e..44da270f 100644 --- a/test/test_mailmap.py +++ b/test/test_mailmap.py @@ -62,14 +62,14 @@ ] -def test_empty(): +def test_empty() -> None: mailmap = Mailmap() for _, _, name, email in TEST_RESOLVE: assert mailmap.resolve(name, email) == (name, email) -def test_new(): +def test_new() -> None: mailmap = Mailmap() # Add entries to the mailmap @@ -80,7 +80,7 @@ def test_new(): assert mailmap.resolve(name, email) == (real_name, real_email) -def test_parsed(): +def test_parsed() -> None: mailmap = Mailmap.from_buffer(TEST_MAILMAP) for real_name, real_email, name, email in TEST_RESOLVE: diff --git a/test/test_merge.py b/test/test_merge.py index 538070f0..492c0034 100644 --- a/test/test_merge.py +++ b/test/test_merge.py @@ -30,28 +30,29 @@ import pytest import pygit2 +from pygit2 import Repository from pygit2.enums import FileStatus, MergeAnalysis, MergeFavor, MergeFileFlag, MergeFlag @pytest.mark.parametrize('id', [None, 42]) -def test_merge_invalid_type(mergerepo, id): +def test_merge_invalid_type(mergerepo: Repository, id: None | int) -> None: with pytest.raises(TypeError): - mergerepo.merge(id) + mergerepo.merge(id) # type:ignore # TODO: Once Repository.merge drops support for str arguments, # add an extra parameter to test_merge_invalid_type above # to make sure we cover legacy code. -def test_merge_string_argument_deprecated(mergerepo): +def test_merge_string_argument_deprecated(mergerepo: Repository) -> None: branch_head_hex = '5ebeeebb320790caf276b9fc8b24546d63316533' with pytest.warns(DeprecationWarning, match=r'Pass Commit.+instead'): mergerepo.merge(branch_head_hex) -def test_merge_analysis_uptodate(mergerepo): +def test_merge_analysis_uptodate(mergerepo: Repository) -> None: branch_head_hex = '5ebeeebb320790caf276b9fc8b24546d63316533' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id analysis, preference = mergerepo.merge_analysis(branch_id) assert analysis & MergeAnalysis.UP_TO_DATE @@ -64,9 +65,9 @@ def test_merge_analysis_uptodate(mergerepo): assert {} == mergerepo.status() -def test_merge_analysis_fastforward(mergerepo): +def test_merge_analysis_fastforward(mergerepo: Repository) -> None: branch_head_hex = 'e97b4cfd5db0fb4ebabf4f203979ca4e5d1c7c87' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id analysis, preference = mergerepo.merge_analysis(branch_id) assert not analysis & MergeAnalysis.UP_TO_DATE @@ -79,9 +80,9 @@ def test_merge_analysis_fastforward(mergerepo): assert {} == mergerepo.status() -def test_merge_no_fastforward_no_conflicts(mergerepo): +def test_merge_no_fastforward_no_conflicts(mergerepo: Repository) -> None: branch_head_hex = '03490f16b15a09913edb3a067a3dc67fbb8d41f1' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id analysis, preference = mergerepo.merge_analysis(branch_id) assert not analysis & MergeAnalysis.UP_TO_DATE assert not analysis & MergeAnalysis.FASTFORWARD @@ -90,7 +91,7 @@ def test_merge_no_fastforward_no_conflicts(mergerepo): assert {} == mergerepo.status() -def test_merge_invalid_hex(mergerepo): +def test_merge_invalid_hex(mergerepo: Repository) -> None: branch_head_hex = '12345678' with ( pytest.raises(KeyError), @@ -99,9 +100,9 @@ def test_merge_invalid_hex(mergerepo): mergerepo.merge(branch_head_hex) -def test_merge_already_something_in_index(mergerepo): +def test_merge_already_something_in_index(mergerepo: Repository) -> None: branch_head_hex = '03490f16b15a09913edb3a067a3dc67fbb8d41f1' - branch_oid = mergerepo.get(branch_head_hex).id + branch_oid = mergerepo[branch_head_hex].id with (Path(mergerepo.workdir) / 'inindex.txt').open('w') as f: f.write('new content') mergerepo.index.add('inindex.txt') @@ -109,9 +110,9 @@ def test_merge_already_something_in_index(mergerepo): mergerepo.merge(branch_oid) -def test_merge_no_fastforward_conflicts(mergerepo): +def test_merge_no_fastforward_conflicts(mergerepo: Repository) -> None: branch_head_hex = '1b2bae55ac95a4be3f8983b86cd579226d0eb247' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id analysis, preference = mergerepo.merge_analysis(branch_id) assert not analysis & MergeAnalysis.UP_TO_DATE @@ -144,7 +145,7 @@ def test_merge_no_fastforward_conflicts(mergerepo): assert {'.gitignore': FileStatus.INDEX_MODIFIED} == mergerepo.status() -def test_merge_remove_conflicts(mergerepo): +def test_merge_remove_conflicts(mergerepo: Repository) -> None: other_branch_tip = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') mergerepo.merge(other_branch_tip) idx = mergerepo.index @@ -154,7 +155,7 @@ def test_merge_remove_conflicts(mergerepo): try: conflicts['.gitignore'] except KeyError: - mergerepo.fail("conflicts['.gitignore'] raised KeyError unexpectedly") + mergerepo.fail("conflicts['.gitignore'] raised KeyError unexpectedly") # type: ignore del idx.conflicts['.gitignore'] with pytest.raises(KeyError): conflicts.__getitem__('.gitignore') @@ -170,14 +171,14 @@ def test_merge_remove_conflicts(mergerepo): MergeFavor.UNION, ], ) -def test_merge_favor(mergerepo, favor): +def test_merge_favor(mergerepo: Repository, favor: MergeFavor) -> None: branch_head = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') mergerepo.merge(branch_head, favor=favor) assert mergerepo.index.conflicts is None -def test_merge_fail_on_conflict(mergerepo): +def test_merge_fail_on_conflict(mergerepo: Repository) -> None: branch_head = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') with pytest.raises(pygit2.GitError, match=r'merge conflicts exist'): @@ -186,7 +187,7 @@ def test_merge_fail_on_conflict(mergerepo): ) -def test_merge_commits(mergerepo): +def test_merge_commits(mergerepo: Repository) -> None: branch_head = pygit2.Oid(hex='03490f16b15a09913edb3a067a3dc67fbb8d41f1') merge_index = mergerepo.merge_commits(mergerepo.head.target, branch_head) @@ -201,7 +202,7 @@ def test_merge_commits(mergerepo): assert merge_tree == merge_commits_tree -def test_merge_commits_favor(mergerepo): +def test_merge_commits_favor(mergerepo: Repository) -> None: branch_head = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') merge_index = mergerepo.merge_commits( @@ -211,10 +212,10 @@ def test_merge_commits_favor(mergerepo): # Incorrect favor value with pytest.raises(TypeError, match=r'favor argument must be MergeFavor'): - mergerepo.merge_commits(mergerepo.head.target, branch_head, favor='foo') + mergerepo.merge_commits(mergerepo.head.target, branch_head, favor='foo') # type: ignore -def test_merge_trees(mergerepo): +def test_merge_trees(mergerepo: Repository) -> None: branch_id = pygit2.Oid(hex='03490f16b15a09913edb3a067a3dc67fbb8d41f1') ancestor_id = mergerepo.merge_base(mergerepo.head.target, branch_id) @@ -230,7 +231,7 @@ def test_merge_trees(mergerepo): assert merge_tree == merge_commits_tree -def test_merge_trees_favor(mergerepo): +def test_merge_trees_favor(mergerepo: Repository) -> None: branch_head_hex = '1b2bae55ac95a4be3f8983b86cd579226d0eb247' ancestor_id = mergerepo.merge_base(mergerepo.head.target, branch_head_hex) merge_index = mergerepo.merge_trees( @@ -240,14 +241,19 @@ def test_merge_trees_favor(mergerepo): with pytest.raises(TypeError): mergerepo.merge_trees( - ancestor_id, mergerepo.head.target, branch_head_hex, favor='foo' + ancestor_id, + mergerepo.head.target, + branch_head_hex, + favor='foo', # type: ignore ) -def test_merge_options(): +def test_merge_options() -> None: favor = MergeFavor.OURS - flags = MergeFlag.FIND_RENAMES | MergeFlag.FAIL_ON_CONFLICT - file_flags = MergeFileFlag.IGNORE_WHITESPACE | MergeFileFlag.DIFF_PATIENCE + flags: int | MergeFlag = MergeFlag.FIND_RENAMES | MergeFlag.FAIL_ON_CONFLICT + file_flags: int | MergeFileFlag = ( + MergeFileFlag.IGNORE_WHITESPACE | MergeFileFlag.DIFF_PATIENCE + ) o1 = pygit2.Repository._merge_options( favor=favor, flags=flags, file_flags=file_flags ) @@ -280,9 +286,9 @@ def test_merge_options(): assert file_flags == o1.file_flags -def test_merge_many(mergerepo): +def test_merge_many(mergerepo: Repository) -> None: branch_head_hex = '03490f16b15a09913edb3a067a3dc67fbb8d41f1' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id ancestor_id = mergerepo.merge_base_many([mergerepo.head.target, branch_id]) merge_index = mergerepo.merge_trees( @@ -299,9 +305,9 @@ def test_merge_many(mergerepo): assert merge_tree == merge_commits_tree -def test_merge_octopus(mergerepo): +def test_merge_octopus(mergerepo: Repository) -> None: branch_head_hex = '03490f16b15a09913edb3a067a3dc67fbb8d41f1' - branch_id = mergerepo.get(branch_head_hex).id + branch_id = mergerepo[branch_head_hex].id ancestor_id = mergerepo.merge_base_octopus([mergerepo.head.target, branch_id]) merge_index = mergerepo.merge_trees( @@ -318,7 +324,7 @@ def test_merge_octopus(mergerepo): assert merge_tree == merge_commits_tree -def test_merge_mergeheads(mergerepo): +def test_merge_mergeheads(mergerepo: Repository) -> None: assert mergerepo.listall_mergeheads() == [] branch_head = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') @@ -332,7 +338,7 @@ def test_merge_mergeheads(mergerepo): ) -def test_merge_message(mergerepo): +def test_merge_message(mergerepo: Repository) -> None: assert not mergerepo.message assert not mergerepo.raw_message @@ -346,7 +352,7 @@ def test_merge_message(mergerepo): assert not mergerepo.message -def test_merge_remove_message(mergerepo): +def test_merge_remove_message(mergerepo: Repository) -> None: branch_head = pygit2.Oid(hex='1b2bae55ac95a4be3f8983b86cd579226d0eb247') mergerepo.merge(branch_head) @@ -355,7 +361,7 @@ def test_merge_remove_message(mergerepo): assert not mergerepo.message -def test_merge_commit(mergerepo): +def test_merge_commit(mergerepo: Repository) -> None: commit = mergerepo['1b2bae55ac95a4be3f8983b86cd579226d0eb247'] assert isinstance(commit, pygit2.Commit) mergerepo.merge(commit) @@ -364,7 +370,7 @@ def test_merge_commit(mergerepo): assert mergerepo.listall_mergeheads() == [commit.id] -def test_merge_reference(mergerepo): +def test_merge_reference(mergerepo: Repository) -> None: branch = mergerepo.branches.local['branch-conflicts'] branch_head_hex = '1b2bae55ac95a4be3f8983b86cd579226d0eb247' mergerepo.merge(branch) diff --git a/test/utils.py b/test/utils.py index ba98f5ce..f4cadbc8 100644 --- a/test/utils.py +++ b/test/utils.py @@ -32,7 +32,7 @@ import zipfile from pathlib import Path from types import TracebackType -from typing import Optional +from typing import Callable, Optional, ParamSpec, TypeVar # Requirements import pytest @@ -40,6 +40,9 @@ # Pygit2 import pygit2 +T = TypeVar('T') +P = ParamSpec('P') + requires_future_libgit2 = pytest.mark.xfail( pygit2.LIBGIT2_VER < (2, 0, 0), reason='This test may work with a future version of libgit2', @@ -121,7 +124,13 @@ def __exit__( pass -def assertRaisesWithArg(exc_class, arg, func, *args, **kwargs): +def assertRaisesWithArg( + exc_class: type[Exception], + arg: object, + func: Callable[P, T], + *args: P.args, + **kwargs: P.kwargs, +) -> None: with pytest.raises(exc_class) as excinfo: func(*args, **kwargs) assert excinfo.value.args == (arg,)