2
2
3
3
from numpy import broadcast_shapes
4
4
5
- reducers = ("sum" , "prod" , "min" , "max" , "std" , "mean" , "var" , "any" , "all" , "slice" )
5
+ lin_alg_funcs = (
6
+ "concat" ,
7
+ "diagonal" ,
8
+ "expand_dims" ,
9
+ "matmul" ,
10
+ "matrix_transpose" ,
11
+ "outer" ,
12
+ "permute_dims" ,
13
+ "squeeze" ,
14
+ "stack" ,
15
+ "tensordot" ,
16
+ "transpose" ,
17
+ "vecdot" ,
18
+ )
19
+ reducers = ("sum" , "prod" , "min" , "max" , "std" , "mean" , "var" , "any" , "all" , "slice" , "count_nonzero" )
6
20
7
21
# All the available constructors and reducers necessary for the (string) expression evaluator
8
22
constructors = (
18
32
"zeros_like" ,
19
33
"ones_like" ,
20
34
"empty_like" ,
35
+ "eye" ,
21
36
)
22
37
# Note that, as reshape is accepted as a method too, it should always come last in the list
23
38
constructors += ("reshape" ,)
@@ -50,6 +65,8 @@ def reduce_shape(shape, axis, keepdims):
50
65
51
66
def slice_shape (shape , slices ):
52
67
"""Infer shape after slicing."""
68
+ if shape is None :
69
+ return None
53
70
result = []
54
71
for dim , sl in zip (shape , slices , strict = False ):
55
72
if isinstance (sl , int ): # indexing removes the axis
@@ -68,11 +85,9 @@ def slice_shape(shape, slices):
68
85
69
86
def elementwise (* args ):
70
87
"""All args must broadcast elementwise."""
71
- shape = args [0 ]
72
- shape = shape if shape is not None else ()
73
- for s in args [1 :]:
74
- shape = broadcast_shapes (shape , s ) if s is not None else shape
75
- return shape
88
+ if None in args :
89
+ return None
90
+ return broadcast_shapes (* args )
76
91
77
92
78
93
# --- Function registry ---
@@ -118,6 +133,9 @@ def visit_Call(self, node): # noqa : C901
118
133
else :
119
134
kwargs [kw .arg ] = self ._lookup_value (kw .value )
120
135
136
+ if func_name in lin_alg_funcs :
137
+ return None # need to implement shape handling for these funcs
138
+
121
139
# ------- handle constructors ---------------
122
140
if func_name in constructors or attr_name == "reshape" :
123
141
# shape kwarg directly provided
@@ -139,7 +157,7 @@ def visit_Call(self, node): # noqa : C901
139
157
if node .args :
140
158
shape_arg = node .args [0 ]
141
159
if isinstance (shape_arg , ast .Tuple ):
142
- shape = tuple (self ._const_or_lookup (e ) for e in shape_arg .elts )
160
+ shape = tuple (self ._lookup_value (e ) for e in shape_arg .elts )
143
161
elif isinstance (shape_arg , ast .Constant ):
144
162
shape = (shape_arg .value ,)
145
163
else :
@@ -149,10 +167,10 @@ def visit_Call(self, node): # noqa : C901
149
167
150
168
# ---- arange ----
151
169
elif func_name == "arange" :
152
- start = self ._const_or_lookup (node .args [0 ]) if node .args else 0
153
- stop = self ._const_or_lookup (node .args [1 ]) if len (node .args ) > 1 else None
154
- step = self ._const_or_lookup (node .args [2 ]) if len (node .args ) > 2 else 1
155
- shape = self ._const_or_lookup (node .args [4 ]) if len (node .args ) > 4 else kwargs .get ("shape" )
170
+ start = self ._lookup_value (node .args [0 ]) if node .args else 0
171
+ stop = self ._lookup_value (node .args [1 ]) if len (node .args ) > 1 else None
172
+ step = self ._lookup_value (node .args [2 ]) if len (node .args ) > 2 else 1
173
+ shape = self ._lookup_value (node .args [4 ]) if len (node .args ) > 4 else kwargs .get ("shape" )
156
174
157
175
if shape is not None :
158
176
return shape if isinstance (shape , tuple ) else (shape ,)
@@ -169,8 +187,8 @@ def visit_Call(self, node): # noqa : C901
169
187
170
188
# ---- linspace ----
171
189
elif func_name == "linspace" :
172
- num = self ._const_or_lookup (node .args [2 ]) if len (node .args ) > 2 else kwargs .get ("num" )
173
- shape = self ._const_or_lookup (node .args [5 ]) if len (node .args ) > 5 else kwargs .get ("shape" )
190
+ num = self ._lookup_value (node .args [2 ]) if len (node .args ) > 2 else kwargs .get ("num" )
191
+ shape = self ._lookup_value (node .args [5 ]) if len (node .args ) > 5 else kwargs .get ("shape" )
174
192
if shape is not None :
175
193
return shape if isinstance (shape , tuple ) else (shape ,)
176
194
if num is not None :
@@ -180,12 +198,16 @@ def visit_Call(self, node): # noqa : C901
180
198
elif func_name == "frombuffer" or func_name == "fromiter" :
181
199
count = kwargs .get ("count" )
182
200
return (count ,) if count else ()
201
+ elif func_name == "eye" :
202
+ N = self ._lookup_value (node .args [0 ])
203
+ M = self ._lookup_value (node .args [1 ]) if len (node .args ) > 1 else kwargs .get ("M" )
204
+ return (N , N ) if M is None else (N , M )
183
205
184
206
elif func_name == "reshape" or attr_name == "reshape" :
185
207
if node .args :
186
208
shape_arg = node .args [- 1 ]
187
209
if isinstance (shape_arg , ast .Tuple ):
188
- return tuple (self ._const_or_lookup (e ) for e in shape_arg .elts )
210
+ return tuple (self ._lookup_value (e ) for e in shape_arg .elts )
189
211
return ()
190
212
191
213
else :
@@ -218,12 +240,13 @@ def visit_Compare(self, node):
218
240
shapes = [self .visit (node .left )] + [self .visit (c ) for c in node .comparators ]
219
241
return elementwise (* shapes )
220
242
243
+ def visit_Constant (self , node ):
244
+ return ()
245
+
221
246
def visit_BinOp (self , node ):
222
247
left = self .visit (node .left )
223
248
right = self .visit (node .right )
224
- left = () if left is None else left
225
- right = () if right is None else right
226
- return broadcast_shapes (left , right )
249
+ return elementwise (left , right )
227
250
228
251
def _eval_slice (self , node ):
229
252
if isinstance (node , ast .Slice ):
@@ -250,15 +273,6 @@ def _lookup_value(self, node):
250
273
else :
251
274
return None
252
275
253
- def _const_or_lookup (self , node ):
254
- """Return constant value or resolve name to scalar from shapes."""
255
- if isinstance (node , ast .Constant ):
256
- return node .value
257
- elif isinstance (node , ast .Name ):
258
- return self .shapes .get (node .id , None )
259
- else :
260
- return None
261
-
262
276
263
277
# --- Public API ---
264
278
def infer_shape (expr , shapes ):
0 commit comments