Skip to content

Commit 103770c

Browse files
QwlouseThe kauldron Authors
authored andcommitted
Update to typeguard 4.4.1 API
PiperOrigin-RevId: 726929133
1 parent 205b34e commit 103770c

File tree

5 files changed

+35
-24
lines changed

5 files changed

+35
-24
lines changed

.github/workflows/pytest_and_autopublish.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ jobs:
3939
# * sweep_utils_test: Depends on kxm
4040
# * lpips_test: Missing VGG weights
4141
# * partial_loader_test: Orbax partial checkpoint loader not yet open-sourced (TODO(epot): Restore)
42-
# * typing tests: Not yet supported due to typeguard version issues.
4342
- name: Run core tests
4443
run: |
4544
pytest -vv -n auto \
@@ -48,9 +47,7 @@ jobs:
4847
--ignore=kauldron/xm/ \
4948
--ignore=kauldron/metrics/lpips_test.py \
5049
--ignore=kauldron/checkpoints/partial_loader_test.py \
51-
--ignore=kauldron/utils/sweep_utils_test.py \
52-
--ignore=kauldron/typing/shape_spec_test.py \
53-
--ignore=kauldron/typing/type_check_test.py
50+
--ignore=kauldron/utils/sweep_utils_test.py
5451
5552
# Auto-publish when version is increased
5653
publish-job:

kauldron/typing/shape_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def _assert_caller_is_typechecked_func() -> None:
8383
if stack[i + 1].function != "_reraise_with_shape_info":
8484
caller_name = stack[i].function
8585
raise AssertionError(
86-
"Dim and Shape not yet supported due to `typeguard` issue."
87-
f" Raised in {caller_name!r}"
86+
"Dim and Shape only work inside of @typechecked functions. But"
87+
f" {caller_name!r} lacks @typechecked."
8888
)
8989

9090

kauldron/typing/type_check.py

Lines changed: 30 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -46,13 +46,10 @@ def check_type(
4646
expected_type: Any,
4747
) -> None:
4848
"""Ensure that value matches expected_type, alias for typeguard.check_type."""
49-
if True: # Typeguard not yet supported
50-
return
5149
return typeguard.check_type(value, expected_type)
5250

5351

54-
exc_cls = Exception
55-
class TypeCheckError(exc_cls):
52+
class TypeCheckError(typeguard.TypeCheckError):
5653
"""Indicates a runtime typechecking error from the @typechecked decorator."""
5754

5855
def __init__(
@@ -112,9 +109,6 @@ def _annotation_repr(ann: Any) -> str:
112109

113110
def typechecked(fn):
114111
"""Decorator to enable runtime type-checking and shape-checking."""
115-
if True: # Typeguard not yet supported
116-
return fn
117-
118112
if hasattr(fn, "__wrapped__"):
119113
raise AssertionError("@typechecked should be the innermost decorator")
120114

@@ -126,25 +120,43 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
126120
# typchecking disabled globally or locally -> just return fn(...)
127121
return fn(*args, **kwargs)
128122

129-
# Find either the first Python wrapper or the actual function
130-
python_func = inspect.unwrap(fn, stop=lambda f: hasattr(f, "__code__"))
123+
sig = inspect.signature(fn)
124+
bound_args = sig.bind(*args, **kwargs)
131125
# manually reproduce the functionality of typeguard.typechecked, so that
132126
# we get access to the returnvalue of the function
133127
localns = sys._getframe(1).f_locals # pylint: disable=protected-access
134-
memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs)
128+
globalns = fn.__globals__ # pylint: disable=protected-access
129+
memo = typeguard.TypeCheckMemo(globalns, localns)
135130
retval = _undef
131+
132+
annotations = typing.get_type_hints(
133+
fn,
134+
globalns=globalns,
135+
localns=localns,
136+
include_extras=True,
137+
)
138+
annotated_arguments = {
139+
k: (v, annotations[k])
140+
for k, v in bound_args.arguments.items()
141+
if k in annotations
142+
}
143+
136144
try:
137-
typeguard.check_argument_types(memo)
145+
typeguard._functions.check_argument_types( # pylint: disable=protected-access
146+
fn.__name__, annotated_arguments, memo=memo
147+
)
138148
retval = fn(*args, **kwargs)
139-
typeguard.check_return_type(retval, memo)
149+
if "return" in annotations:
150+
typeguard._functions.check_return_type( # pylint: disable=protected-access
151+
fn.__name__, retval, annotations["return"], memo
152+
)
140153
return retval
141154
except typeguard.TypeCheckError as e:
142155
# Use function signature to construct a complete list of named arguments
143156
sig = inspect.signature(fn)
144157
bound_args = sig.bind(*args, **kwargs)
145158
bound_args.apply_defaults()
146159

147-
annotations = {k: p.annotation for k, p in sig.parameters.items()}
148160
# TODO(klausg): filter the stacktrace to exclude all the typechecking
149161
raise TypeCheckError(
150162
str(e),
@@ -396,7 +408,7 @@ def _custom_dataclass_checker(
396408
dataclass_as_typed_dict.__module__ = origin_type.__module__
397409
values = {k.name: getattr(value, k.name) for k in fields}
398410
try:
399-
return typeguard.check_type(
411+
return typeguard.check_type_internal(
400412
dataclass_as_typed_dict(**values),
401413
dataclass_as_typed_dict,
402414
memo=memo,
@@ -469,3 +481,7 @@ def add_custom_checker_lookup_fn(lookup_fn):
469481
break
470482
else: # prepend
471483
checker_lookup_fns[:0] = [lookup_fn]
484+
485+
486+
add_custom_checker_lookup_fn(_array_spec_checker_lookup)
487+
add_custom_checker_lookup_fn(_dataclass_checker_lookup)

kauldron/typing/type_check_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import jaxtyping as jt
1818
from kauldron.typing import Float, TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member
19-
from kauldron.typing import type_check
19+
from kauldron.typing import type_check # pylint: disable=g-bad-import-order
2020
import numpy as np
2121
import pytest
2222

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ dependencies = [
4242
"tfds-nightly", # TODO(klausg): switch back to tensorflow_datasets>=4.9.7
4343
# once released: https://github.com/tensorflow/datasets/commit/d4bfd59863c6cb5b64d043b7cb6ab566e7d92440
4444
"tqdm",
45-
# TODO(klausg): Restore typeguard or switch to something else
46-
# closest match to the internal typeguard
47-
# "typeguard@git+https://github.com/agronholm/typeguard@0dd7f7510b7c694e66a0d17d1d58d185125bad5d",
45+
"typeguard>=4.4.1",
4846
"typing_extensions",
4947
"xmanager",
5048
# lazy deps (should ideally remove those)

0 commit comments

Comments
 (0)