Skip to content

Commit 60cb2a7

Browse files
QwlouseThe kauldron Authors
authored andcommitted
Update to typeguard 4.4.1 API
PiperOrigin-RevId: 726929133
1 parent 205b34e commit 60cb2a7

File tree

5 files changed

+30
-22
lines changed

5 files changed

+30
-22
lines changed

.github/workflows/pytest_and_autopublish.yml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ jobs:
3939
# * sweep_utils_test: Depends on kxm
4040
# * lpips_test: Missing VGG weights
4141
# * partial_loader_test: Orbax partial checkpoint loader not yet open-sourced (TODO(epot): Restore)
42-
# * typing tests: Not yet supported due to typeguard version issues.
4342
- name: Run core tests
4443
run: |
4544
pytest -vv -n auto \
@@ -48,9 +47,7 @@ jobs:
4847
--ignore=kauldron/xm/ \
4948
--ignore=kauldron/metrics/lpips_test.py \
5049
--ignore=kauldron/checkpoints/partial_loader_test.py \
51-
--ignore=kauldron/utils/sweep_utils_test.py \
52-
--ignore=kauldron/typing/shape_spec_test.py \
53-
--ignore=kauldron/typing/type_check_test.py
50+
--ignore=kauldron/utils/sweep_utils_test.py
5451
5552
# Auto-publish when version is increased
5653
publish-job:

kauldron/typing/shape_spec.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,8 +83,8 @@ def _assert_caller_is_typechecked_func() -> None:
8383
if stack[i + 1].function != "_reraise_with_shape_info":
8484
caller_name = stack[i].function
8585
raise AssertionError(
86-
"Dim and Shape not yet supported due to `typeguard` issue."
87-
f" Raised in {caller_name!r}"
86+
"Dim and Shape only work inside of @typechecked functions. But"
87+
f" {caller_name!r} lacks @typechecked."
8888
)
8989

9090

kauldron/typing/type_check.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,12 @@ def check_type(
4646
expected_type: Any,
4747
) -> None:
4848
"""Ensure that value matches expected_type, alias for typeguard.check_type."""
49-
if True: # Typeguard not yet supported
50-
return
5149
return typeguard.check_type(value, expected_type)
5250

5351

54-
exc_cls = Exception
52+
exc_cls = typeguard.TypeCheckError
53+
54+
5555
class TypeCheckError(exc_cls):
5656
"""Indicates a runtime typechecking error from the @typechecked decorator."""
5757

@@ -112,9 +112,6 @@ def _annotation_repr(ann: Any) -> str:
112112

113113
def typechecked(fn):
114114
"""Decorator to enable runtime type-checking and shape-checking."""
115-
if True: # Typeguard not yet supported
116-
return fn
117-
118115
if hasattr(fn, "__wrapped__"):
119116
raise AssertionError("@typechecked should be the innermost decorator")
120117

@@ -126,17 +123,29 @@ def _reraise_with_shape_info(*args, _typecheck: bool = True, **kwargs):
126123
# typchecking disabled globally or locally -> just return fn(...)
127124
return fn(*args, **kwargs)
128125

129-
# Find either the first Python wrapper or the actual function
130-
python_func = inspect.unwrap(fn, stop=lambda f: hasattr(f, "__code__"))
126+
sig = inspect.signature(fn)
127+
bound_args = sig.bind(*args, **kwargs)
131128
# manually reproduce the functionality of typeguard.typechecked, so that
132129
# we get access to the returnvalue of the function
133130
localns = sys._getframe(1).f_locals # pylint: disable=protected-access
134-
memo = typeguard.CallMemo(python_func, localns, args=args, kwargs=kwargs)
131+
globalns = sys._getframe(1).f_globals # pylint: disable=protected-access
132+
memo = typeguard.TypeCheckMemo(globalns, localns)
135133
retval = _undef
134+
annotated_arguments = {
135+
k: (v, sig.parameters[k].annotation)
136+
for k, v in bound_args.arguments.items()
137+
if sig.parameters[k].annotation != inspect.Parameter.empty
138+
}
139+
136140
try:
137-
typeguard.check_argument_types(memo)
141+
typeguard._functions.check_argument_types( # pylint: disable=protected-access
142+
fn.__name__, annotated_arguments, memo=memo
143+
)
138144
retval = fn(*args, **kwargs)
139-
typeguard.check_return_type(retval, memo)
145+
if sig.return_annotation is not inspect.Parameter.empty:
146+
typeguard._functions.check_return_type( # pylint: disable=protected-access
147+
fn.__name__, retval, sig.return_annotation, memo
148+
)
140149
return retval
141150
except typeguard.TypeCheckError as e:
142151
# Use function signature to construct a complete list of named arguments
@@ -396,7 +405,7 @@ def _custom_dataclass_checker(
396405
dataclass_as_typed_dict.__module__ = origin_type.__module__
397406
values = {k.name: getattr(value, k.name) for k in fields}
398407
try:
399-
return typeguard.check_type(
408+
return typeguard.check_type_internal(
400409
dataclass_as_typed_dict(**values),
401410
dataclass_as_typed_dict,
402411
memo=memo,
@@ -469,3 +478,7 @@ def add_custom_checker_lookup_fn(lookup_fn):
469478
break
470479
else: # prepend
471480
checker_lookup_fns[:0] = [lookup_fn]
481+
482+
483+
add_custom_checker_lookup_fn(_array_spec_checker_lookup)
484+
add_custom_checker_lookup_fn(_dataclass_checker_lookup)

kauldron/typing/type_check_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
import jaxtyping as jt
1818
from kauldron.typing import Float, TypeCheckError, typechecked # pylint: disable=g-multiple-import,g-importing-member
19-
from kauldron.typing import type_check
19+
from kauldron.typing import type_check # pylint: disable=g-bad-import-order
2020
import numpy as np
2121
import pytest
2222

pyproject.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,7 @@ dependencies = [
4242
"tfds-nightly", # TODO(klausg): switch back to tensorflow_datasets>=4.9.7
4343
# once released: https://github.com/tensorflow/datasets/commit/d4bfd59863c6cb5b64d043b7cb6ab566e7d92440
4444
"tqdm",
45-
# TODO(klausg): Restore typeguard or switch to something else
46-
# closest match to the internal typeguard
47-
# "typeguard@git+https://github.com/agronholm/typeguard@0dd7f7510b7c694e66a0d17d1d58d185125bad5d",
45+
"typeguard>=4.4.1",
4846
"typing_extensions",
4947
"xmanager",
5048
# lazy deps (should ideally remove those)

0 commit comments

Comments
 (0)