Skip to content

Commit 7424ce8

Browse files
committed
Add shape_utils.py file
1 parent 73aa182 commit 7424ce8

File tree

1 file changed

+267
-0
lines changed

1 file changed

+267
-0
lines changed

src/blosc2/shape_utils.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
import ast
2+
3+
from numpy import broadcast_shapes
4+
5+
# Function/method names whose result shape is computed by collapsing axes
# (each of them is registered in FUNCTIONS further down).
reducers = (
    "sum",
    "prod",
    "min",
    "max",
    "std",
    "mean",
    "var",
    "any",
    "all",
    "slice",
)

# All the available constructors and reducers necessary for the (string) expression evaluator.
# Note that, as reshape is accepted as a method too, it must always stay the last entry.
constructors = (
    "arange",
    "linspace",
    "fromiter",
    "zeros",
    "ones",
    "empty",
    "full",
    "frombuffer",
    "full_like",
    "zeros_like",
    "ones_like",
    "empty_like",
    "reshape",
)
24+
25+
26+
# --- Shape utilities ---
27+
def reduce_shape(shape, axis, keepdims):
    """Return the shape obtained by reducing ``shape`` over ``axis``.

    ``shape`` is a tuple of extents, or ``None`` when the shape is unknown (in
    which case ``None`` is returned).  ``axis`` may be ``None`` (reduce over
    every axis), a single int, or an iterable of ints; negative values count
    from the last axis.  With ``keepdims`` truthy the reduced axes are kept
    with extent 1, otherwise they are dropped.
    """
    if shape is None:
        # Nothing can be inferred from an unknown shape
        return None

    if axis is None:
        # Reduction over all axes
        if keepdims:
            return (1,) * len(shape)
        return ()

    # Accept both a single axis and a collection of axes
    requested = (axis,) if isinstance(axis, int) else tuple(axis)

    # Map negative axes onto their non-negative equivalent
    reduced = tuple(ax + len(shape) if ax < 0 else ax for ax in requested)

    if keepdims:
        return tuple(1 if pos in reduced else extent for pos, extent in enumerate(shape))
    return tuple(extent for pos, extent in enumerate(shape) if pos not in reduced)
49+
50+
51+
def slice_shape(shape, slices):
52+
"""Infer shape after slicing."""
53+
result = []
54+
for dim, sl in zip(shape, slices, strict=False):
55+
if isinstance(sl, int): # indexing removes the axis
56+
continue
57+
if isinstance(sl, slice):
58+
start = sl.start or 0
59+
stop = sl.stop if sl.stop is not None else dim
60+
step = sl.step or 1
61+
length = max(0, (stop - start + (step - 1)) // step)
62+
result.append(length)
63+
else:
64+
raise ValueError(f"Unsupported slice type: {sl}")
65+
result.extend(shape[len(slices) :]) # untouched trailing dims
66+
return tuple(result)
67+
68+
69+
def elementwise(*args):
    """Broadcast all the given shapes together and return the result.

    Each argument is a shape, or ``None`` when nothing is known about it.  A
    leading ``None`` is treated as the scalar shape ``()``; any later ``None``
    is skipped.  At least one argument is required.
    """
    result = args[0]
    if result is None:
        result = ()
    for other in args[1:]:
        if other is None:
            continue
        result = broadcast_shapes(result, other)
    return result
76+
77+
78+
# --- Function registry ---
79+
def _reduction_shape(x, axis=None, keepdims=False, out=None, **_ignored):
    """Return the shape of a reduction applied to an operand of shape ``x``.

    Only ``axis`` and ``keepdims`` influence the result.  ``out`` and any other
    keyword (``dtype``, ``ddof``, ...) are accepted and ignored, so that a call
    carrying them does not make shape inference fail with a ``TypeError``.
    """
    return reduce_shape(x, axis, keepdims)


# Every reducer shares the same shape rule.
# Any function not registered here will default to elementwise.
FUNCTIONS = {func: _reduction_shape for func in reducers}
84+
85+
86+
# --- AST Shape Inferencer ---
87+
class ShapeInferencer(ast.NodeVisitor):
    """AST visitor that infers the shape of an expression without evaluating it.

    ``shapes`` maps each symbol used in the expression to its shape (a tuple)
    or to a plain scalar value.

    Every ``visit_*`` method returns a shape.  Nodes without a dedicated
    visitor (for instance bare constants) fall back to
    ``ast.NodeVisitor.generic_visit`` and yield ``None``, which the visitors
    in this class treat as "no shape information".
    """

    def __init__(self, shapes):
        self.shapes = shapes

    def visit_Name(self, node):
        """Return the shape registered for a symbol (``()`` for scalar values).

        Raises ``ValueError`` when the symbol is not present in ``self.shapes``.
        """
        if node.id not in self.shapes:
            raise ValueError(f"Unknown symbol: {node.id}")
        s = self.shapes[node.id]
        if isinstance(s, tuple):
            return s
        else:  # passed a scalar value
            return ()

    def visit_Call(self, node):  # noqa : C901
        """Infer the shape returned by a function or method call.

        Handled, in this order: constructors (and ``.reshape``), the
        ``.slice(...)`` method, the reductions registered in ``FUNCTIONS``,
        and finally any other call, which is assumed to broadcast its
        operands elementwise.
        """
        func_name = getattr(node.func, "id", None)
        attr_name = getattr(node.func, "attr", None)

        # --- Recursive method-chain support ---
        obj_shape = None
        if isinstance(node.func, ast.Attribute):
            obj_shape = self.visit(node.func.value)

        # --- Parse keyword args ---
        kwargs = {kw.arg: self._value_or_tuple(kw.value) for kw in node.keywords}

        # ------- handle constructors ---------------
        if func_name in constructors or attr_name == "reshape":
            # shape kwarg directly provided
            if "shape" in kwargs:
                val = kwargs["shape"]
                return val if isinstance(val, tuple) else (val,)

            # ---- array constructors like zeros, ones, full, etc. ----
            elif func_name in (
                "zeros",
                "ones",
                "empty",
                "full",
                "full_like",
                "zeros_like",
                "empty_like",
                "ones_like",
            ):
                # Without positional args, fall through to the generic handling below
                if node.args:
                    shape = self._value_or_tuple(node.args[0])
                    return shape if isinstance(shape, tuple) else (shape,)

            # ---- arange ----
            elif func_name == "arange":
                n_args = len(node.args)
                # Positional values win; keywords are only a fallback
                start = self._const_or_lookup(node.args[0]) if n_args > 0 else kwargs.get("start", 0)
                stop = self._const_or_lookup(node.args[1]) if n_args > 1 else kwargs.get("stop")
                step = self._const_or_lookup(node.args[2]) if n_args > 2 else kwargs.get("step", 1)
                shape = self._value_or_tuple(node.args[4]) if n_args > 4 else kwargs.get("shape")

                if shape is not None:
                    return shape if isinstance(shape, tuple) else (shape,)

                # Fallback to numeric difference if possible
                if stop is None:
                    stop, start = start, 0
                try:
                    num = int((stop - start) / step)
                except Exception:
                    # symbolic or non-numeric: unknown 1D (deliberately best-effort)
                    return ((),)
                return (max(num, 0),)

            # ---- linspace ----
            elif func_name == "linspace":
                num = self._const_or_lookup(node.args[2]) if len(node.args) > 2 else kwargs.get("num")
                shape = self._value_or_tuple(node.args[5]) if len(node.args) > 5 else kwargs.get("shape")
                if shape is not None:
                    return shape if isinstance(shape, tuple) else (shape,)
                if num is not None:
                    return (num,)
                raise ValueError("linspace requires either shape or num argument")

            elif func_name == "frombuffer" or func_name == "fromiter":
                count = kwargs.get("count")
                return (count,) if count else ()

            elif func_name == "reshape" or attr_name == "reshape":
                if node.args:
                    shape_arg = node.args[-1]
                    if isinstance(shape_arg, ast.Tuple):
                        dims = tuple(self._const_or_lookup(e) for e in shape_arg.elts)
                        # A negative entry (e.g. -1) is a placeholder rather than a
                        # real extent, so it is reported as unknown (None).
                        return tuple(None if isinstance(d, int) and d < 0 else d for d in dims)
                return ()

            else:
                raise ValueError(f"Unrecognized constructor or missing shape argument for {func_name}")

        # --- Special-case .slice((slice(...), ...)) ---
        if attr_name == "slice":
            if not node.args:
                raise ValueError(".slice() requires an argument")
            slice_arg = node.args[0]
            if isinstance(slice_arg, ast.Tuple):
                slices = [self._eval_slice(s) for s in slice_arg.elts]
            else:
                slices = [self._eval_slice(slice_arg)]
            return slice_shape(obj_shape, slices)

        # --- Evaluate argument shapes normally ---
        args = [self.visit(arg) for arg in node.args]

        if func_name in FUNCTIONS:
            # Function form: the first positional is the operand, the next one the axis
            operand = args[0] if args else None
            return self._reduce(FUNCTIONS[func_name], operand, node.args[1:], kwargs)
        if attr_name in FUNCTIONS:
            # Method form: the object is the operand, the first positional the axis
            return self._reduce(FUNCTIONS[attr_name], obj_shape, node.args, kwargs)

        shapes = [obj_shape] + args if obj_shape is not None else args
        shapes = [s for s in shapes if s is not None]
        return elementwise(*shapes) if shapes else ()

    def visit_Compare(self, node):
        """Comparisons broadcast all their operands elementwise."""
        shapes = [self.visit(node.left)] + [self.visit(c) for c in node.comparators]
        return elementwise(*shapes)

    def visit_BinOp(self, node):
        """Binary operators broadcast both operands; unknown shapes count as scalars."""
        left = self.visit(node.left)
        right = self.visit(node.right)
        left = () if left is None else left
        right = () if right is None else right
        return broadcast_shapes(left, right)

    def visit_UnaryOp(self, node):
        """Unary operators (``-a``, ``+a``, ``~a``, ``not a``) keep the operand's shape."""
        return self.visit(node.operand)

    def _reduce(self, reducer, operand, extra_args, kwargs):
        """Call ``reducer`` on ``operand`` using the axis/keepdims of the call.

        ``extra_args`` are the AST nodes of the positional arguments that
        follow the operand.  Only the first one is used, as the axis, and only
        when no ``axis`` keyword was given.  Other keywords are not forwarded.
        """
        axis = kwargs.get("axis")
        if "axis" not in kwargs and extra_args:
            candidate = self._value_or_tuple(extra_args[0])
            if isinstance(candidate, (int, tuple)):
                axis = candidate
        return reducer(operand, axis=axis, keepdims=kwargs.get("keepdims", False))

    def _slice_bound(self, node):
        """Resolve one slice bound (start/stop/step) to an int, else ``None``."""
        if node is None:
            return None
        value = self._lookup_value(node)
        return value if isinstance(value, int) else None

    def _eval_slice(self, node):
        """Turn the AST of one ``.slice()`` item into a ``slice`` or an index.

        Raises ``ValueError`` for expressions that are neither a slice, a
        ``slice(...)`` call, a constant, nor a signed numeric literal.
        """
        if isinstance(node, ast.Slice):
            return slice(
                self._slice_bound(node.lower),
                self._slice_bound(node.upper),
                self._slice_bound(node.step),
            )
        if isinstance(node, ast.Call) and getattr(node.func, "id", None) == "slice":
            # handle explicit slice() constructor
            return slice(*(self._slice_bound(a) for a in node.args))
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.UnaryOp):
            # negative indices such as -1 are parsed as UnaryOp(USub, Constant)
            value = self._lookup_value(node)
            if value is not None:
                return value
        raise ValueError(f"Unsupported slice expression: {ast.dump(node)}")

    def _value_or_tuple(self, node):
        """Resolve a tuple/list literal element by element, anything else as one value."""
        if isinstance(node, (ast.Tuple, ast.List)):
            return tuple(self._lookup_value(e) for e in node.elts)
        return self._lookup_value(node)

    def _lookup_value(self, node):
        """Look up a value in self.shapes if node is a variable name, else constant value.

        Signed numbers (``-1``, ``+2``) are resolved too, since the parser
        represents them as a unary operator applied to a constant.  Anything
        else resolves to ``None``.
        """
        if isinstance(node, ast.Name):
            return self.shapes.get(node.id, None)
        if isinstance(node, ast.Constant):
            return node.value
        if isinstance(node, ast.UnaryOp) and isinstance(node.op, (ast.USub, ast.UAdd)):
            operand = self._lookup_value(node.operand)
            if isinstance(operand, (int, float)) and not isinstance(operand, bool):
                return -operand if isinstance(node.op, ast.USub) else operand
        return None

    def _const_or_lookup(self, node):
        """Return constant value or resolve name to scalar from shapes."""
        return self._lookup_value(node)
261+
262+
263+
# --- Public API ---
264+
def infer_shape(expr, shapes):
    """Infer the shape of the result of ``expr`` without evaluating it.

    ``expr`` is the source of a single Python expression and ``shapes`` maps
    the symbols it uses to their shapes (tuples) or scalar values.  Returns
    whatever ``ShapeInferencer`` yields for the expression's root node.
    """
    if isinstance(expr, str):
        # In "eval" mode the parser rejects leading whitespace as an unexpected
        # indent, so surrounding blanks/newlines are removed first.
        expr = expr.strip()
    tree = ast.parse(expr, mode="eval")
    inferencer = ShapeInferencer(shapes)
    return inferencer.visit(tree.body)

0 commit comments

Comments
 (0)