Skip to content

Commit e3ed43f

Browse files
committed
initial commit of pygit2 callbacks.py
1 parent 40f44c1 commit e3ed43f

File tree

1 file changed

+47
-28
lines changed

1 file changed

+47
-28
lines changed

pygit2/callbacks.py

Lines changed: 47 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -62,10 +62,12 @@
6262
API.
6363
"""
6464

65+
from __future__ import annotations
66+
6567
# Standard Library
6668
from contextlib import contextmanager
6769
from functools import wraps
68-
from typing import Optional, Union
70+
from typing import TYPE_CHECKING, Any, Callable, Optional, Union
6971

7072
# pygit2
7173
from ._pygit2 import DiffFile, Oid
@@ -74,19 +76,23 @@
7476
from .ffi import C, ffi
7577
from .utils import StrArray, maybe_string, ptr_to_bytes, to_bytes
7678

79+
if TYPE_CHECKING:
80+
from .remotes import Remote, TransferProgress
81+
from .repository import Repository
82+
7783
#
7884
# The payload is the way to pass information from the pygit2 API, through
7985
# libgit2, to the Python callbacks. And back.
8086
#
8187

8288

8389
class Payload:
84-
def __init__(self, **kw):
90+
def __init__(self, **kw: Any):
8591
for key, value in kw.items():
8692
setattr(self, key, value)
87-
self._stored_exception = None
93+
self._stored_exception: BaseException | None = None
8894

89-
def check_error(self, error_code):
95+
def check_error(self, error_code: int):
9096
if error_code == C.GIT_EUSER:
9197
assert self._stored_exception is not None
9298
raise self._stored_exception
@@ -112,14 +118,18 @@ class RemoteCallbacks(Payload):
112118
RemoteCallbacks(certificate=certificate).
113119
"""
114120

121+
if TYPE_CHECKING:
122+
repository: Callable[[str, bool], Repository]
123+
remote: Callable[[Repository, str, str], Remote]
124+
115125
def __init__(self, credentials=None, certificate_check=None):
116126
super().__init__()
117127
if credentials is not None:
118128
self.credentials = credentials
119129
if certificate_check is not None:
120130
self.certificate_check = certificate_check
121131

122-
def sideband_progress(self, string):
132+
def sideband_progress(self, string: str):
123133
"""
124134
Progress output callback. Override this function with your own
125135
progress reporting function
@@ -158,7 +168,7 @@ def credentials(
158168
"""
159169
raise Passthrough
160170

161-
def certificate_check(self, certificate, valid, host):
171+
def certificate_check(self, certificate: None, valid: bool, host: str):
162172
"""
163173
Certificate callback. Override with your own function to determine
164174
whether to accept the server's certificate.
@@ -180,7 +190,7 @@ def certificate_check(self, certificate, valid, host):
180190

181191
raise Passthrough
182192

183-
def transfer_progress(self, stats):
193+
def transfer_progress(self, stats: TransferProgress):
184194
"""
185195
Transfer progress callback. Override with your own function to report
186196
transfer progress.
@@ -191,7 +201,7 @@ def transfer_progress(self, stats):
191201
The progress up to now.
192202
"""
193203

194-
def update_tips(self, refname, old, new):
204+
def update_tips(self, refname: str, old: Oid, new: Oid):
195205
"""
196206
Update tips callback. Override with your own function to report
197207
reference updates.
@@ -208,7 +218,7 @@ def update_tips(self, refname, old, new):
208218
The reference's new value.
209219
"""
210220

211-
def push_update_reference(self, refname, message):
221+
def push_update_reference(self, refname: str, message: str | None):
212222
"""
213223
Push update reference callback. Override with your own function to
214224
report the remote's acceptance or rejection of reference updates.
@@ -415,9 +425,9 @@ def git_remote_callbacks(payload):
415425
#
416426

417427

418-
def libgit2_callback(f):
428+
def libgit2_callback(f: Callable[..., int]):
419429
@wraps(f)
420-
def wrapper(*args):
430+
def wrapper(*args: Any) -> int:
421431
data = ffi.from_handle(args[-1])
422432
args = args[:-1] + (data,)
423433
try:
@@ -436,10 +446,10 @@ def wrapper(*args):
436446
return ffi.def_extern()(wrapper)
437447

438448

439-
def libgit2_callback_void(f):
449+
def libgit2_callback_void(f: Callable[..., object]):
440450
@wraps(f)
441-
def wrapper(*args):
442-
data = ffi.from_handle(args[-1])
451+
def wrapper(*args: Any):
452+
data: Payload = ffi.from_handle(args[-1])
443453
args = args[:-1] + (data,)
444454
try:
445455
f(*args)
@@ -457,7 +467,7 @@ def wrapper(*args):
457467

458468

459469
@libgit2_callback
460-
def _certificate_check_cb(cert_i, valid, host, data):
470+
def _certificate_check_cb(cert_i, valid: int, host, data: RemoteCallbacks):
461471
# We want to simulate what should happen if libgit2 supported pass-through
462472
# for this callback. For SSH, 'valid' is always False, because it doesn't
463473
# look at known_hosts, but we do want to let it through in order to do what
@@ -479,7 +489,7 @@ def _certificate_check_cb(cert_i, valid, host, data):
479489

480490

481491
@libgit2_callback
482-
def _credentials_cb(cred_out, url, username, allowed, data):
492+
def _credentials_cb(cred_out, url, username, allowed: int, data: RemoteCallbacks):
483493
credentials = getattr(data, 'credentials', None)
484494
if not credentials:
485495
return 0
@@ -493,7 +503,7 @@ def _credentials_cb(cred_out, url, username, allowed, data):
493503

494504

495505
@libgit2_callback
496-
def _push_update_reference_cb(ref, msg, data):
506+
def _push_update_reference_cb(ref, msg, data: RemoteCallbacks):
497507
push_update_reference = getattr(data, 'push_update_reference', None)
498508
if not push_update_reference:
499509
return 0
@@ -519,7 +529,7 @@ def _remote_create_cb(remote_out, repo, name, url, data):
519529

520530

521531
@libgit2_callback
522-
def _repository_create_cb(repo_out, path, bare, data):
532+
def _repository_create_cb(repo_out, path, bare: int, data: RemoteCallbacks):
523533
repository = data.repository(ffi.string(path), bare != 0)
524534
# we no longer own the C object
525535
repository._disown()
@@ -529,7 +539,7 @@ def _repository_create_cb(repo_out, path, bare, data):
529539

530540

531541
@libgit2_callback
532-
def _sideband_progress_cb(string, length, data):
542+
def _sideband_progress_cb(string, length: int, data: RemoteCallbacks):
533543
sideband_progress = getattr(data, 'sideband_progress', None)
534544
if not sideband_progress:
535545
return 0
@@ -540,7 +550,7 @@ def _sideband_progress_cb(string, length, data):
540550

541551

542552
@libgit2_callback
543-
def _transfer_progress_cb(stats_ptr, data):
553+
def _transfer_progress_cb(stats_ptr, data: RemoteCallbacks):
544554
from .remotes import TransferProgress
545555

546556
transfer_progress = getattr(data, 'transfer_progress', None)
@@ -552,7 +562,7 @@ def _transfer_progress_cb(stats_ptr, data):
552562

553563

554564
@libgit2_callback
555-
def _update_tips_cb(refname, a, b, data):
565+
def _update_tips_cb(refname, a, b, data: RemoteCallbacks):
556566
update_tips = getattr(data, 'update_tips', None)
557567
if not update_tips:
558568
return 0
@@ -569,7 +579,7 @@ def _update_tips_cb(refname, a, b, data):
569579
#
570580

571581

572-
def get_credentials(fn, url, username, allowed):
582+
def get_credentials(fn, url, username, allowed: CredentialType):
573583
"""Call fn and return the credentials object."""
574584
url_str = maybe_string(url)
575585
username_str = maybe_string(username)
@@ -633,7 +643,7 @@ def get_credentials(fn, url, username, allowed):
633643

634644
@libgit2_callback
635645
def _checkout_notify_cb(
636-
why, path_cstr, baseline, target, workdir, data: CheckoutCallbacks
646+
why: int, path_cstr, baseline, target, workdir, data: CheckoutCallbacks
637647
):
638648
pypath = maybe_string(path_cstr)
639649
pybaseline = DiffFile.from_c(ptr_to_bytes(baseline))
@@ -660,8 +670,8 @@ def _checkout_progress_cb(path, completed_steps, total_steps, data: CheckoutCall
660670

661671

662672
def _git_checkout_options(
663-
callbacks=None,
664-
strategy=None,
673+
callbacks: CheckoutCallbacks | None = None,
674+
strategy: CheckoutStrategy | None = None,
665675
directory=None,
666676
paths=None,
667677
c_checkout_options_ptr=None,
@@ -693,7 +703,7 @@ def _git_checkout_options(
693703

694704
if paths:
695705
strarray = StrArray(paths)
696-
refs.append(strarray)
706+
refs.append(strarray) # type: ignore
697707
opts.paths = strarray.ptr[0]
698708

699709
# If we want to receive any notifications, set up notify_cb in the options
@@ -717,7 +727,12 @@ def _git_checkout_options(
717727

718728

719729
@contextmanager
720-
def git_checkout_options(callbacks=None, strategy=None, directory=None, paths=None):
730+
def git_checkout_options(
731+
callbacks: CheckoutCallbacks | None = None,
732+
strategy=None,
733+
directory=None,
734+
paths=None,
735+
):
721736
yield _git_checkout_options(
722737
callbacks=callbacks, strategy=strategy, directory=directory, paths=paths
723738
)
@@ -746,7 +761,11 @@ def _stash_apply_progress_cb(progress: StashApplyProgress, data: StashApplyCallb
746761

747762
@contextmanager
748763
def git_stash_apply_options(
749-
callbacks=None, reinstate_index=False, strategy=None, directory=None, paths=None
764+
callbacks: StashApplyCallbacks | None = None,
765+
reinstate_index: bool = False,
766+
strategy=None,
767+
directory=None,
768+
paths=None,
750769
):
751770
if callbacks is None:
752771
callbacks = StashApplyCallbacks()

0 commit comments

Comments
 (0)