Skip to content

Commit 5655ce1

Browse files
authored
Merge pull request #303 from arogozhnikov/dev
Allow anonymous axes in parse_shape, fix #302
2 parents d495e7c + 5b41d90 commit 5655ce1

File tree

3 files changed

+47
-149
lines changed

3 files changed

+47
-149
lines changed

einops/einops.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -687,9 +687,19 @@ def parse_shape(x, pattern: str) -> dict:
687687
else:
688688
composition = exp.composition
689689
result = {}
690-
for (axis_name,), axis_length in zip(composition, shape): # type: ignore
691-
if axis_name != "_":
692-
result[axis_name] = axis_length
690+
for axes, axis_length in zip(composition, shape): # type: ignore
691+
# axes either [], or [AnonymousAxis] or ['axis_name']
692+
if len(axes) == 0:
693+
if axis_length != 1:
694+
raise RuntimeError(f"Length of axis is not 1: {pattern} {shape}")
695+
else:
696+
[axis] = axes
697+
if isinstance(axis, str):
698+
if axis != "_":
699+
result[axis] = axis_length
700+
else:
701+
if axis.value != axis_length:
702+
raise RuntimeError(f"Length of anonymous axis does not match: {pattern} {shape}")
693703
return result
694704

695705

einops/experimental/data_api_packing.py

Lines changed: 0 additions & 137 deletions
This file was deleted.

tests/test_other.py

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,8 +121,9 @@ def test_repeating(self):
121121
with pytest.raises(einops.EinopsError):
122122
parse_shape(self.backend.from_numpy(self.x), "a a b b")
123123

124-
@parameterized.expand(
125-
[
124+
125+
def test_ellipsis(self):
126+
for shape, pattern, expected in [
126127
([10, 20], "...", dict()),
127128
([10], "... a", dict(a=10)),
128129
([10, 20], "... a", dict(a=20)),
@@ -134,13 +135,37 @@ def test_repeating(self):
134135
([10, 20, 30, 40], "a ...", dict(a=10)),
135136
([10, 20, 30, 40], " a ... b", dict(a=10, b=40)),
136137
([10, 40], " a ... b", dict(a=10, b=40)),
137-
]
138-
)
139-
def test_ellipsis(self, shape: List[int], pattern: str, expected: Dict[str, int]):
140-
x = numpy.ones(shape)
141-
parsed1 = parse_shape(x, pattern)
142-
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
143-
assert parsed1 == parsed2 == expected
138+
]:
139+
x = numpy.ones(shape)
140+
parsed1 = parse_shape(x, pattern)
141+
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
142+
assert parsed1 == parsed2 == expected
143+
144+
def test_parse_with_anonymous_axes(self):
145+
for shape, pattern, expected in [
146+
([1, 2, 3, 4], "1 2 3 a", dict(a=4)),
147+
([10, 1, 2], "a 1 2", dict(a=10)),
148+
([10, 1, 2], "a () 2", dict(a=10)),
149+
]:
150+
x = numpy.ones(shape)
151+
parsed1 = parse_shape(x, pattern)
152+
parsed2 = parse_shape(self.backend.from_numpy(x), pattern)
153+
assert parsed1 == parsed2 == expected
154+
155+
156+
def test_failures(self):
157+
# every test should fail
158+
for shape, pattern in [
159+
([1, 2, 3, 4], "a b c"),
160+
([1, 2, 3, 4], "2 a b c"),
161+
([1, 2, 3, 4], "a b c ()"),
162+
([1, 2, 3, 4], "a b c d e"),
163+
([1, 2, 3, 4], "a b c d e ..."),
164+
([1, 2, 3, 4], "a b c ()"),
165+
]:
166+
with pytest.raises(RuntimeError):
167+
x = numpy.ones(shape)
168+
parse_shape(self.backend.from_numpy(x), pattern)
144169

145170

146171
_SYMBOLIC_BACKENDS = [

0 commit comments

Comments
 (0)