Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion kauldron/ktyping/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ class PointCloud:
color: UInt8["n 3"]

# Fails because n from pos is different than n from color:
p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3, dtype=np.uint8)))
p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3), dtype=np.uint8))
```

Decorating any `dataclass` with `@kt.typechecked` has two effects:
Expand Down
2 changes: 1 addition & 1 deletion kauldron/ktyping/array_type_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def __getitem__(
case (array_types, str(shape_spec), dtype):
# Long form: Array[np.ndarray, "a b", np.float32]
name = f"{cls.__name__}[{array_types!r}, {shape_spec!r}, {dtype!r}]"
if isinstance(array_types, tuple):
if not isinstance(array_types, tuple):
array_types = (array_types,)
case _:
name = cls.__name__
Expand Down
2 changes: 1 addition & 1 deletion kauldron/ktyping/array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def _create_x_type(name, dtype=MISSING):


# Generic array type.
XArray = _create_x_type("XArray", None)
XArray = _create_x_type("XArray")

# Any numerical type: int, float, complex, etc. but NOT bool.
# see also https://numpy.org/doc/2.1/reference/arrays.scalars.html#scalars
Expand Down
6 changes: 6 additions & 0 deletions kauldron/ktyping/array_types_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,3 +217,9 @@ def test_shape():
def test_shape_call_raises():
with pytest.raises(RuntimeError, match="cannot be instantiated"):
art.Shape("b n")


def test_xarray_accepts_any_dtype():
for dtype in (np.float32, np.int32, np.bool_, np.complex64, np.uint8):
assert art.XArray.dtype_matches(np.empty((), dtype=dtype))
assert isinstance(np.zeros((2, 3), dtype=dtype), art.XArray)
3 changes: 2 additions & 1 deletion kauldron/ktyping/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID:

Args:
module_regex: A regex that matches the module name of the source code. This
is match against the module name of typecked function or class with
is matched against the module name of typecked function or class with
`re.match` (so the match is anchored at the start but not the end of the
module name)
config: The config to override the default config with. Only fields that are
Expand All @@ -104,6 +104,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID:
The ID of the config override. This can be used to remove the override
using `kt.remove_config_override`.
"""

cfg_id = uuid.uuid4()
CONFIG_OVERRIDES[cfg_id] = (module_regex, config)
return cfg_id
Expand Down
2 changes: 1 addition & 1 deletion kauldron/ktyping/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ def _typechecked_generator_wrapper(*args, **kwargs):
# Check argument types
for argname, (value, annot) in annotated_args.items():
check.assert_not_never(gen_fn, annot)
print(f"Checking argument {argname} = {value} with annot {annot}")

try:
check.check_type_internal(value, annot)
except errors.TypeCheckError as exc:
Expand Down
5 changes: 4 additions & 1 deletion kauldron/ktyping/dim_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def __contains__(self, name: str) -> bool:
"""Returns True if the dimension is defined in all candidates."""
__ktyping_ignore_frame__ = True # pylint: disable=unused-variable

_, name = _get_dim_type(name)
values = {alt.get(name, MISSING) for alt in self._scope.candidates}
if MISSING in values:
return False
Expand Down Expand Up @@ -230,7 +231,9 @@ def _format_dim_assignment(dim_name: str, value: DimValue, align: int) -> str:
def _format_ambiguous_dim(
dim_name: str, values: list[DimValue], align: int
) -> str:
is_tuple = any(len(v) > 1 for v in values)
"""Formats a dimension with multiple ambiguous candidate values."""
defined_values = [v for v in values if v is not MISSING]
is_tuple = any(len(v) > 1 for v in defined_values)

if is_tuple:
dim_name = f"*{dim_name}".rjust(align)
Expand Down
2 changes: 1 addition & 1 deletion kauldron/ktyping/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,6 @@ def shape_error_message(
)
else:
return (
f"has shape {value.shape} which is not shape-compatible with any of"
f"has shape {value.shape} which is not shape-compatible with any of"
f" {acceptable_shapes!r}{array_spec_str}."
)
2 changes: 1 addition & 1 deletion kauldron/ktyping/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def __init__(
self.arguments = arguments if arguments is not None else {}
self.annotations = annotations if annotations is not None else {}
self.fstring_locals = fstring_locals if fstring_locals is not None else {}
self.default_args = default_args or set()
self.default_args = default_args if default_args is not None else ()
self.return_value = MISSING

self._check_for_jaxtyping_annotations()
Expand Down
20 changes: 20 additions & 0 deletions kauldron/ktyping/scope_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,3 +195,23 @@ def test_dim():
with kt.ShapeScope():
kt.dim["a"] = 1
assert kt.dim["a"] == 1


def test_dim_contains_with_prefix():
with kt.ShapeScope(candidates=[{"b": (8, 16)}]) as scope:
assert "b" in scope.dim
assert "*b" in scope.dim
assert "c" not in scope.dim
assert "*c" not in scope.dim

with kt.ShapeScope(candidates=[{"a": (5,)}]) as scope:
assert "a" in scope.dim
assert "*a" in scope.dim


def test_dim_str_with_partially_defined_candidates():
candidate1 = {"a": (1,), "b": (2, 3)}
candidate2 = {"a": (4,)}
with kt.ShapeScope(candidates=[candidate1, candidate2]) as scope:
assert "a" in scope.dim
assert "b" not in scope.dim
2 changes: 1 addition & 1 deletion kauldron/ktyping/shape_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,7 @@ def evaluate(self, dim_values: DimValues) -> tuple[int, ...]:
raise ShapeError(f"Cannot evaluate anonymous dimension: {self!r}")

def evaluate_all(self, dim_values: DimValues) -> Iterator[Shape]:
raise StopIteration() # cannot be evaluated / is underconstrained
yield from () # cannot be evaluated / is underconstrained

def get_all_prefix_matches(
self, shape: Shape, dim_values: DimValues
Expand Down
4 changes: 4 additions & 0 deletions kauldron/ktyping/shape_spec_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ def anon_var_dim(self, args: list[Any]) -> shape_spec.AnonDims:
name = str(args[0]) if args else None
return shape_spec.AnonDims(name=name, length=None)

def anon_plus_dim(self, args: list[Any]) -> shape_spec.AnonDims:
name = str(args[0]) if args else None
return shape_spec.AnonDims(name=name, length=(1, None))

def broadcast_int_dim(self, args: list[Any]) -> shape_spec.IntDim:
return shape_spec.IntDim(value=int(args[0]), broadcastable=True)

Expand Down
22 changes: 22 additions & 0 deletions kauldron/ktyping/shape_spec_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@
ChoiceDim(left=NamedDims("b"), right=NamedDims("c")),
),
),
(
"+_foo",
ShapeSpec(AnonDims("foo", length=(1, None))),
),
]


Expand All @@ -106,6 +110,24 @@ def test_shape_parser(spec_str, expected_spec):
assert repr(expected_spec) == spec_str


@pytest.mark.parametrize(
"spec_str",
[
"2*a+b",
"a+b*c",
"(a+b)*c",
"a*b**c",
"(a*b)**c",
"a+b|c",
"-a+b",
"a//b+c",
],
)
def test_repr_roundtrip(spec_str):
parsed_spec = shape_spec_parser.parse(spec_str)
assert repr(parsed_spec) == spec_str


@pytest.mark.parametrize(
"spec_str",
[
Expand Down
2 changes: 1 addition & 1 deletion kauldron/ktyping/shape_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ def _eval_shape(spec_str: str, candidates: CandidateDims) -> Shape:
# }
raise shape_spec.ShapeError(
f"{spec_str!r} is ambiguous under the current set of possible"
" dim_values. Could be one of:/n - "
" dim_values. Could be one of:\n - "
+ "\n - ".join(f"{k!r}" for k in valid_shapes)
)
return valid_shapes.pop() # return the only valid shape
Loading