Skip to content

Commit 9e12341

Browse files
QwlouseThe kauldron Authors
authored andcommitted
several minor fixes
PiperOrigin-RevId: 869597403
1 parent 66af85d commit 9e12341

File tree

14 files changed

+66
-10
lines changed

14 files changed

+66
-10
lines changed

kauldron/ktyping/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ class PointCloud:
144144
color: UInt8["n 3"]
145145

146146
# Fails because n from pos is different than n from color:
147-
p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3, dtype=np.uint8)))
147+
p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3), dtype=np.uint8))
148148
```
149149

150150
Decorating any `dataclass` with `@kt.typechecked` has two effects:

kauldron/ktyping/array_type_meta.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def __getitem__(
131131
case (array_types, str(shape_spec), dtype):
132132
# Long form: Array[np.ndarray, "a b", np.float32]
133133
name = f"{cls.__name__}[{array_types!r}, {shape_spec!r}, {dtype!r}]"
134-
if isinstance(array_types, tuple):
134+
if not isinstance(array_types, tuple):
135135
array_types = (array_types,)
136136
case _:
137137
name = cls.__name__

kauldron/ktyping/array_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def _create_x_type(name, dtype=MISSING):
135135

136136

137137
# Generic array type.
138-
XArray = _create_x_type("XArray", None)
138+
XArray = _create_x_type("XArray")
139139

140140
# Any numerical type: int, float, complex, etc. but NOT bool.
141141
# see also https://numpy.org/doc/2.1/reference/arrays.scalars.html#scalars

kauldron/ktyping/array_types_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,3 +217,9 @@ def test_shape():
217217
def test_shape_call_raises():
218218
with pytest.raises(RuntimeError, match="cannot be instantiated"):
219219
art.Shape("b n")
220+
221+
222+
def test_xarray_accepts_any_dtype():
223+
for dtype in (np.float32, np.int32, np.bool_, np.complex64, np.uint8):
224+
assert art.XArray.dtype_matches(np.empty((), dtype=dtype))
225+
assert isinstance(np.zeros((2, 3), dtype=dtype), art.XArray)

kauldron/ktyping/config.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID:
9494
9595
Args:
9696
module_regex: A regex that matches the module name of the source code. This
97-
is match against the module name of typecked function or class with
97+
is matched against the module name of typecked function or class with
9898
`re.match` (so the match is anchored at the start but not the end of the
9999
module name)
100100
config: The config to override the default config with. Only fields that are
@@ -104,6 +104,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID:
104104
The ID of the config override. This can be used to remove the override
105105
using `kt.remove_config_override`.
106106
"""
107+
107108
cfg_id = uuid.uuid4()
108109
CONFIG_OVERRIDES[cfg_id] = (module_regex, config)
109110
return cfg_id

kauldron/ktyping/decorator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def _typechecked_generator_wrapper(*args, **kwargs):
256256
# Check argument types
257257
for argname, (value, annot) in annotated_args.items():
258258
check.assert_not_never(gen_fn, annot)
259-
print(f"Checking argument {argname} = {value} with annot {annot}")
259+
260260
try:
261261
check.check_type_internal(value, annot)
262262
except errors.TypeCheckError as exc:

kauldron/ktyping/dim_view.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,7 @@ def __contains__(self, name: str) -> bool:
9898
"""Returns True if the dimension is defined in all candidates."""
9999
__ktyping_ignore_frame__ = True # pylint: disable=unused-variable
100100

101+
_, name = _get_dim_type(name)
101102
values = {alt.get(name, MISSING) for alt in self._scope.candidates}
102103
if MISSING in values:
103104
return False
@@ -230,7 +231,9 @@ def _format_dim_assignment(dim_name: str, value: DimValue, align: int) -> str:
230231
def _format_ambiguous_dim(
231232
dim_name: str, values: list[DimValue], align: int
232233
) -> str:
233-
is_tuple = any(len(v) > 1 for v in values)
234+
"""Formats a dimension with multiple ambiguous candidate values."""
235+
defined_values = [v for v in values if v is not MISSING]
236+
is_tuple = any(len(v) > 1 for v in defined_values)
234237

235238
if is_tuple:
236239
dim_name = f"*{dim_name}".rjust(align)

kauldron/ktyping/errors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,6 @@ def shape_error_message(
245245
)
246246
else:
247247
return (
248-
f"has shape {value.shape} which is not shape-compatible with any of"
248+
f"has shape {value.shape} which is not shape-compatible with any of"
249249
f" {acceptable_shapes!r}{array_spec_str}."
250250
)

kauldron/ktyping/scope.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
self.arguments = arguments if arguments is not None else {}
8686
self.annotations = annotations if annotations is not None else {}
8787
self.fstring_locals = fstring_locals if fstring_locals is not None else {}
88-
self.default_args = default_args or set()
88+
self.default_args = default_args if default_args is not None else ()
8989
self.return_value = MISSING
9090

9191
self._check_for_jaxtyping_annotations()

kauldron/ktyping/scope_test.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -195,3 +195,23 @@ def test_dim():
195195
with kt.ShapeScope():
196196
kt.dim["a"] = 1
197197
assert kt.dim["a"] == 1
198+
199+
200+
def test_dim_contains_with_prefix():
201+
with kt.ShapeScope(candidates=[{"b": (8, 16)}]) as scope:
202+
assert "b" in scope.dim
203+
assert "*b" in scope.dim
204+
assert "c" not in scope.dim
205+
assert "*c" not in scope.dim
206+
207+
with kt.ShapeScope(candidates=[{"a": (5,)}]) as scope:
208+
assert "a" in scope.dim
209+
assert "*a" in scope.dim
210+
211+
212+
def test_dim_str_with_partially_defined_candidates():
213+
candidate1 = {"a": (1,), "b": (2, 3)}
214+
candidate2 = {"a": (4,)}
215+
with kt.ShapeScope(candidates=[candidate1, candidate2]) as scope:
216+
assert "a" in scope.dim
217+
assert "b" not in scope.dim

0 commit comments

Comments
 (0)