Skip to content

Commit d59799f

Browse files
Uri Grantauri-granta
authored andcommitted
Support check_shapes with partial functions
1 parent 14b1c09 commit d59799f

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

check_shapes/error_contexts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,7 @@ class FunctionDefinitionContext(ErrorContext):
323323
func: Callable[..., Any]
324324

325325
def print(self, builder: MessageBuilder) -> None:
326-
name = self.func.__qualname__
326+
name = getattr(self.func, "__qualname__", repr(self.func))
327327
try:
328328
path = inspect.getsourcefile(self.func)
329329
except Exception: # pragma: no cover

tests/check_shapes/test_decorator.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# pylint: disable=unused-argument # Bunch of fake functions below has unused arguments.
1616

1717
from dataclasses import dataclass
18+
from functools import partial
1819
from typing import Mapping, Optional, Sequence, Tuple
1920

2021
import pytest
@@ -779,3 +780,18 @@ def f(a: TestShaped, b: TestShaped) -> TestShaped:
779780
"""
780781
== f.__doc__
781782
)
783+
784+
785+
def test_check_shapes__supports_partial() -> None:
786+
def f(a: TestShaped, d: int) -> TestShaped:
787+
return t(3, d)
788+
789+
shape_check = check_shapes(
790+
"a: [10]",
791+
"return: [3, 10]",
792+
)
793+
partial_f = shape_check(partial(f, d=10))
794+
795+
partial_f(t(10))
796+
with pytest.raises(ShapeMismatchError):
797+
partial_f(t(5))

0 commit comments

Comments
 (0)