Skip to content

Commit 984cf6b

Browse files
QwlouseThe kauldron Authors
authored andcommitted
Fix UNKNOWN_DIM formatting in ktyping error messages.
PiperOrigin-RevId: 889752042
1 parent ca1ec84 commit 984cf6b

File tree

4 files changed

+74
-15
lines changed

4 files changed

+74
-15
lines changed

kauldron/ktyping/dim_view.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from kauldron.ktyping import errors
2323
from kauldron.ktyping import internal_typing
2424
from kauldron.ktyping import scope as kscope
25+
from kauldron.ktyping import utils
2526

2627
DimValue = internal_typing.DimValue
2728
DimValues = internal_typing.DimValues
@@ -217,19 +218,7 @@ def __str__(self) -> str:
217218

218219
def _format_dim_value(value: DimValue) -> str:
219220
"""Formats a DimValue tuple into a human-readable string."""
220-
221-
def _fmt(v):
222-
if isinstance(v, int):
223-
return str(v)
224-
if v == UNKNOWN_DIM:
225-
return "#"
226-
return f"&{v}"
227-
228-
str_values = [_fmt(v) for v in value]
229-
if len(value) == 1:
230-
return str_values[0]
231-
else:
232-
return f"({', '.join(str_values)})"
221+
return utils.format_dim_value(value)
233222

234223

235224
def _format_dim_assignment(dim_name: str, value: DimValue, align: int) -> str:

kauldron/ktyping/errors.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,9 @@ def __str__(self):
204204

205205
def _format_dim_assignment(dim, value):
206206
if len(value) == 1:
207-
return f"{dim}: {value[0]}"
207+
return f"{dim}: {utils.format_dim_value(value)}"
208208
else:
209-
return f"*{dim}: {value}"
209+
return f"*{dim}: {utils.format_dim_value(value)}"
210210

211211

212212
# MARK: Error messages

kauldron/ktyping/errors_test.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,3 +61,55 @@ def test_array_type_check():
6161

6262
assert exc.candidates_block == "Dim Assignments:\n - {a: 2, b: 3}"
6363
assert exc.candidates_block in msg
64+
65+
66+
def test_format_dim_assignment_with_unknown_dim():
67+
"""UNKNOWN_DIM should be formatted as '#' in dim assignments."""
68+
from kauldron.ktyping import internal_typing # pylint: disable=g-import-not-at-top
69+
70+
unknown = internal_typing.UNKNOWN_DIM
71+
72+
# Single unknown dim
73+
result = errors._format_dim_assignment("batch", (unknown,))
74+
assert result == "batch: #", result
75+
76+
# Multi-dim with unknown
77+
result = errors._format_dim_assignment("data", (6, unknown))
78+
assert result == "*data: (6, #)", result
79+
80+
# Multi-dim all unknown
81+
result = errors._format_dim_assignment("b", (unknown, unknown))
82+
assert result == "*b: (#, #)", result
83+
84+
# No unknown (sanity check)
85+
result = errors._format_dim_assignment("x", (3, 5))
86+
assert result == "*x: (3, 5)", result
87+
88+
89+
@typechecked
90+
def _broadcastable_fn(
91+
x: Float["*#batch n"],
92+
y: Int["*#batch m"],
93+
) -> Float["n"]:
94+
del y
95+
return x[..., :, 0] # wrong return shape to trigger error
96+
97+
98+
def test_unknown_dim_in_error_message():
99+
"""UNKNOWN_DIM in candidates_block should show '#' not the enum repr."""
100+
x = np.zeros((1, 6, 4), dtype=np.float32) # #*batch=(1, 6) -> (#, 6)
101+
y = np.zeros((1, 6, 3), dtype=np.int32)
102+
103+
with pytest.raises(errors.KTypeCheckError) as exc_info:
104+
_broadcastable_fn(x, y)
105+
106+
msg = str(exc_info.value)
107+
# The ugly enum repr should NOT appear in the error message
108+
assert (
109+
"UnknownDim" not in msg
110+
), f"UnknownDim leaked into error message:\n{msg}"
111+
assert (
112+
"UNKNOWN_DIM" not in msg
113+
), f"UNKNOWN_DIM leaked into error message:\n{msg}"
114+
# The '#' formatting should be used instead
115+
assert "#" in exc_info.value.candidates_block

kauldron/ktyping/utils.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,24 @@ def _get_array_type_shorthand(array: Any) -> str:
111111
return get_type_name(array)
112112

113113

114+
# MARK: format_dim_value
115+
def format_dim_value(value: internal_typing.DimValue) -> str:
116+
"""Formats a DimValue tuple into a human-readable string."""
117+
118+
def _fmt(v):
119+
if isinstance(v, int):
120+
return str(v)
121+
if v == internal_typing.UNKNOWN_DIM:
122+
return "#"
123+
return f"&{v}"
124+
125+
str_values = [_fmt(v) for v in value]
126+
if len(value) == 1:
127+
return str_values[0]
128+
else:
129+
return f"({', '.join(str_values)})"
130+
131+
114132
# MARK: get_type_hints
115133
def get_type_hints(fn: Callable[..., Any]) -> dict[str, Any]:
116134
"""Return the type hints for the given function with caching."""

0 commit comments

Comments
 (0)