diff --git a/pygit2/remotes.py b/pygit2/remotes.py index ad34ee45..b3da7855 100644 --- a/pygit2/remotes.py +++ b/pygit2/remotes.py @@ -25,7 +25,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Generator, Iterator, Literal, TypedDict +import warnings +from typing import TYPE_CHECKING, Any, Generator, Iterator, Literal # Import from pygit2 from pygit2 import RemoteCallbacks @@ -50,12 +51,46 @@ from .repository import BaseRepository -class LsRemotesDict(TypedDict): +class RemoteHead: + """ + Description of a reference advertised by a remote server, + given out on `Remote.ls_remotes` calls. + """ + local: bool - loid: None | Oid + """Available locally""" + + oid: Oid + + loid: Oid + name: str | None + symref_target: str | None - oid: Oid + """ + If the server sent a symref mapping for this ref, this will + point to the target. + """ + + def __init__(self, c_struct: Any) -> None: + self.local = bool(c_struct.local) + self.oid = Oid(raw=bytes(ffi.buffer(c_struct.oid.id)[:])) + self.loid = Oid(raw=bytes(ffi.buffer(c_struct.loid.id)[:])) + self.name = maybe_string(c_struct.name) + self.symref_target = maybe_string(c_struct.symref_target) + + def __getitem__(self, item: str) -> Any: + """ + DEPRECATED: Backwards compatibility with legacy user code + that expects this object to be a dictionary with string keys. + """ + warnings.warn( + 'ls_remotes no longer returns a dict. ' + 'Update your code to read from fields instead ' + '(e.g. result["name"] --> result.name)', + DeprecationWarning, + ) + return getattr(self, item) class PushUpdate: @@ -228,10 +263,10 @@ def ls_remotes( callbacks: RemoteCallbacks | None = None, proxy: str | None | bool = None, connect: bool = True, - ) -> list[LsRemotesDict]: + ) -> list[RemoteHead]: """ - Return a list of dicts that maps to `git_remote_head` from a - `ls_remotes` call. + Get the list of references with which the server responds to a new + connection. Parameters: @@ -247,32 +282,14 @@ def ls_remotes( if connect: self.connect(callbacks=callbacks, proxy=proxy) - refs = ffi.new('git_remote_head ***') - refs_len = ffi.new('size_t *') + refs_ptr = ffi.new('git_remote_head ***') + size_ptr = ffi.new('size_t *') - err = C.git_remote_ls(refs, refs_len, self._remote) + err = C.git_remote_ls(refs_ptr, size_ptr, self._remote) check_error(err) - results = [] - for i in range(int(refs_len[0])): - ref = refs[0][i] - local = bool(ref.local) - if local: - loid = Oid(raw=bytes(ffi.buffer(ref.loid.id)[:])) - else: - loid = None - - remote = LsRemotesDict( - { - 'local': local, - 'loid': loid, - 'name': maybe_string(ref.name), - 'symref_target': maybe_string(ref.symref_target), - 'oid': Oid(raw=bytes(ffi.buffer(ref.oid.id)[:])), - } - ) - - results.append(remote) + num_refs = int(size_ptr[0]) + results = [RemoteHead(refs_ptr[0][i]) for i in range(num_refs)] return results diff --git a/test/test_remote.py b/test/test_remote.py index 5a7a5027..d13bb38b 100644 --- a/test/test_remote.py +++ b/test/test_remote.py @@ -201,7 +201,21 @@ def test_ls_remotes(testrepo: Repository) -> None: assert refs # Check that a known ref is returned. - assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2')) + assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2')) + + +@utils.requires_network +def test_ls_remotes_backwards_compatibility(testrepo: Repository) -> None: + assert 1 == len(testrepo.remotes) + remote = testrepo.remotes[0] + refs = remote.ls_remotes() + ref = refs[0] + + for field in ('name', 'oid', 'loid', 'local', 'symref_target'): + new_value = getattr(ref, field) + with pytest.warns(DeprecationWarning, match='no longer returns a dict'): + old_value = ref[field] + assert new_value == old_value @utils.requires_network @@ -217,7 +231,7 @@ def test_ls_remotes_without_implicit_connect(testrepo: Repository) -> None: assert refs # Check that a known ref is returned. - assert next(iter(r for r in refs if r['name'] == 'refs/tags/v0.28.2')) + assert next(iter(r for r in refs if r.name == 'refs/tags/v0.28.2')) def test_remote_collection(testrepo: Repository) -> None: