diff --git a/kauldron/ktyping/README.md b/kauldron/ktyping/README.md index 56e79414..014f02bc 100644 --- a/kauldron/ktyping/README.md +++ b/kauldron/ktyping/README.md @@ -144,7 +144,7 @@ class PointCloud: color: UInt8["n 3"] # Fails because n from pos is different than n from color: -p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3, dtype=np.uint8))) +p = PointCloud(pos=np.zeros((17, 3)), color=np.ones((16, 3), dtype=np.uint8)) ``` Decorating any `dataclass` with `@kt.typechecked` has two effects: diff --git a/kauldron/ktyping/array_type_meta.py b/kauldron/ktyping/array_type_meta.py index a9ef205d..a149739f 100644 --- a/kauldron/ktyping/array_type_meta.py +++ b/kauldron/ktyping/array_type_meta.py @@ -131,7 +131,7 @@ def __getitem__( case (array_types, str(shape_spec), dtype): # Long form: Array[np.ndarray, "a b", np.float32] name = f"{cls.__name__}[{array_types!r}, {shape_spec!r}, {dtype!r}]" - if isinstance(array_types, tuple): + if not isinstance(array_types, tuple): array_types = (array_types,) case _: name = cls.__name__ diff --git a/kauldron/ktyping/array_types.py b/kauldron/ktyping/array_types.py index e4bd49a7..636fea28 100644 --- a/kauldron/ktyping/array_types.py +++ b/kauldron/ktyping/array_types.py @@ -135,7 +135,7 @@ def _create_x_type(name, dtype=MISSING): # Generic array type. -XArray = _create_x_type("XArray", None) +XArray = _create_x_type("XArray") # Any numerical type: int, float, complex, etc. but NOT bool. # see also https://numpy.org/doc/2.1/reference/arrays.scalars.html#scalars diff --git a/kauldron/ktyping/array_types_test.py b/kauldron/ktyping/array_types_test.py index 771e02bf..8135f23f 100644 --- a/kauldron/ktyping/array_types_test.py +++ b/kauldron/ktyping/array_types_test.py @@ -217,3 +217,9 @@ def test_shape(): def test_shape_call_raises(): with pytest.raises(RuntimeError, match="cannot be instantiated"): art.Shape("b n") + + +def test_xarray_accepts_any_dtype(): + for dtype in (np.float32, np.int32, np.bool_, np.complex64, np.uint8): + assert art.XArray.dtype_matches(np.empty((), dtype=dtype)) + assert isinstance(np.zeros((2, 3), dtype=dtype), art.XArray) diff --git a/kauldron/ktyping/config.py b/kauldron/ktyping/config.py index 3bc406e5..a5f2e32f 100644 --- a/kauldron/ktyping/config.py +++ b/kauldron/ktyping/config.py @@ -94,7 +94,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID: Args: module_regex: A regex that matches the module name of the source code. This - is match against the module name of typecked function or class with + is matched against the module name of typecked function or class with `re.match` (so the match is anchored at the start but not the end of the module name) config: The config to override the default config with. Only fields that are @@ -104,6 +104,7 @@ def add_config_override(module_regex: str, config: Config) -> uuid.UUID: The ID of the config override. This can be used to remove the override using `kt.remove_config_override`. """ + cfg_id = uuid.uuid4() CONFIG_OVERRIDES[cfg_id] = (module_regex, config) return cfg_id diff --git a/kauldron/ktyping/decorator.py b/kauldron/ktyping/decorator.py index f35ce6c7..111fd8ef 100644 --- a/kauldron/ktyping/decorator.py +++ b/kauldron/ktyping/decorator.py @@ -256,7 +256,7 @@ def _typechecked_generator_wrapper(*args, **kwargs): # Check argument types for argname, (value, annot) in annotated_args.items(): check.assert_not_never(gen_fn, annot) - print(f"Checking argument {argname} = {value} with annot {annot}") + try: check.check_type_internal(value, annot) except errors.TypeCheckError as exc: diff --git a/kauldron/ktyping/dim_view.py b/kauldron/ktyping/dim_view.py index a9d739c2..d72e4512 100644 --- a/kauldron/ktyping/dim_view.py +++ b/kauldron/ktyping/dim_view.py @@ -98,6 +98,7 @@ def __contains__(self, name: str) -> bool: """Returns True if the dimension is defined in all candidates.""" __ktyping_ignore_frame__ = True # pylint: disable=unused-variable + _, name = _get_dim_type(name) values = {alt.get(name, MISSING) for alt in self._scope.candidates} if MISSING in values: return False @@ -230,7 +231,9 @@ def _format_dim_assignment(dim_name: str, value: DimValue, align: int) -> str: def _format_ambiguous_dim( dim_name: str, values: list[DimValue], align: int ) -> str: - is_tuple = any(len(v) > 1 for v in values) + """Formats a dimension with multiple ambiguous candidate values.""" + defined_values = [v for v in values if v is not MISSING] + is_tuple = any(len(v) > 1 for v in defined_values) if is_tuple: dim_name = f"*{dim_name}".rjust(align) diff --git a/kauldron/ktyping/errors.py b/kauldron/ktyping/errors.py index d25109a6..b7886371 100644 --- a/kauldron/ktyping/errors.py +++ b/kauldron/ktyping/errors.py @@ -245,6 +245,6 @@ def shape_error_message( ) else: return ( - f"has shape {value.shape} which is not shape-compatible with any of" + f"has shape {value.shape} which is not shape-compatible with any of" f" {acceptable_shapes!r}{array_spec_str}." ) diff --git a/kauldron/ktyping/scope.py b/kauldron/ktyping/scope.py index f6d7171a..59f20614 100644 --- a/kauldron/ktyping/scope.py +++ b/kauldron/ktyping/scope.py @@ -85,7 +85,7 @@ def __init__( self.arguments = arguments if arguments is not None else {} self.annotations = annotations if annotations is not None else {} self.fstring_locals = fstring_locals if fstring_locals is not None else {} - self.default_args = default_args or set() + self.default_args = default_args if default_args is not None else () self.return_value = MISSING self._check_for_jaxtyping_annotations() diff --git a/kauldron/ktyping/scope_test.py b/kauldron/ktyping/scope_test.py index 870d270e..75c8cb5e 100644 --- a/kauldron/ktyping/scope_test.py +++ b/kauldron/ktyping/scope_test.py @@ -195,3 +195,23 @@ def test_dim(): with kt.ShapeScope(): kt.dim["a"] = 1 assert kt.dim["a"] == 1 + + +def test_dim_contains_with_prefix(): + with kt.ShapeScope(candidates=[{"b": (8, 16)}]) as scope: + assert "b" in scope.dim + assert "*b" in scope.dim + assert "c" not in scope.dim + assert "*c" not in scope.dim + + with kt.ShapeScope(candidates=[{"a": (5,)}]) as scope: + assert "a" in scope.dim + assert "*a" in scope.dim + + +def test_dim_str_with_partially_defined_candidates(): + candidate1 = {"a": (1,), "b": (2, 3)} + candidate2 = {"a": (4,)} + with kt.ShapeScope(candidates=[candidate1, candidate2]) as scope: + assert "a" in scope.dim + assert "b" not in scope.dim diff --git a/kauldron/ktyping/shape_spec.py b/kauldron/ktyping/shape_spec.py index 990d8bea..63a9a848 100644 --- a/kauldron/ktyping/shape_spec.py +++ b/kauldron/ktyping/shape_spec.py @@ -406,7 +406,7 @@ def evaluate(self, dim_values: DimValues) -> tuple[int, ...]: raise ShapeError(f"Cannot evaluate anonymous dimension: {self!r}") def evaluate_all(self, dim_values: DimValues) -> Iterator[Shape]: - raise StopIteration() # cannot be evaluated / is underconstrained + yield from () # cannot be evaluated / is underconstrained def get_all_prefix_matches( self, shape: Shape, dim_values: DimValues diff --git a/kauldron/ktyping/shape_spec_parser.py b/kauldron/ktyping/shape_spec_parser.py index fc638997..8af4bcc1 100644 --- a/kauldron/ktyping/shape_spec_parser.py +++ b/kauldron/ktyping/shape_spec_parser.py @@ -52,6 +52,10 @@ def anon_var_dim(self, args: list[Any]) -> shape_spec.AnonDims: name = str(args[0]) if args else None return shape_spec.AnonDims(name=name, length=None) + def anon_plus_dim(self, args: list[Any]) -> shape_spec.AnonDims: + name = str(args[0]) if args else None + return shape_spec.AnonDims(name=name, length=(1, None)) + def broadcast_int_dim(self, args: list[Any]) -> shape_spec.IntDim: return shape_spec.IntDim(value=int(args[0]), broadcastable=True) diff --git a/kauldron/ktyping/shape_spec_parser_test.py b/kauldron/ktyping/shape_spec_parser_test.py index 5fb63ba5..7c9f3d5d 100644 --- a/kauldron/ktyping/shape_spec_parser_test.py +++ b/kauldron/ktyping/shape_spec_parser_test.py @@ -96,6 +96,10 @@ ChoiceDim(left=NamedDims("b"), right=NamedDims("c")), ), ), + ( + "+_foo", + ShapeSpec(AnonDims("foo", length=(1, None))), + ), ] @@ -106,6 +110,24 @@ def test_shape_parser(spec_str, expected_spec): assert repr(expected_spec) == spec_str +@pytest.mark.parametrize( + "spec_str", + [ + "2*a+b", + "a+b*c", + "(a+b)*c", + "a*b**c", + "(a*b)**c", + "a+b|c", + "-a+b", + "a//b+c", + ], +) +def test_repr_roundtrip(spec_str): + parsed_spec = shape_spec_parser.parse(spec_str) + assert repr(parsed_spec) == spec_str + + @pytest.mark.parametrize( "spec_str", [ diff --git a/kauldron/ktyping/shape_tools.py b/kauldron/ktyping/shape_tools.py index 59b182e5..e44b7dc3 100644 --- a/kauldron/ktyping/shape_tools.py +++ b/kauldron/ktyping/shape_tools.py @@ -96,7 +96,7 @@ def _eval_shape(spec_str: str, candidates: CandidateDims) -> Shape: # } raise shape_spec.ShapeError( f"{spec_str!r} is ambiguous under the current set of possible" - " dim_values. Could be one of:/n - " + " dim_values. Could be one of:\n - " + "\n - ".join(f"{k!r}" for k in valid_shapes) ) return valid_shapes.pop() # return the only valid shape