@@ -102,108 +102,79 @@ def local_subtensor_of_dot(fgraph, node):
102102 return [r ]
103103
104104
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_elemwise(fgraph, node):
    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.

        exp(x)[:, 0] -> exp(x[:, 0])
        add(x, y)[0] -> add(x[0], y[0])
        add(x[None], y)[2] -> add(x, y[2])

    Returns the rewritten output in a list, or None if the rewrite does not apply.
    """
    elem, *idx = node.inputs

    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
        return None

    if len(fgraph.clients[elem]) > 1:
        # The Elemwise output is used beyond this Subtensor; lifting would
        # duplicate the full-size Elemwise computation. Get out.
        return None

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    elem_inputs = elem.owner.inputs
    elem_bcast = elem.type.broadcastable
    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
        # Every input has the output's broadcast pattern, so there is no
        # implicit broadcasting: the original indices apply to all inputs.
        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]

    else:
        # The original indices may not make sense on dimensions some inputs
        # broadcast, so build a per-input (mutable) copy of the indices.
        new_idxs = [list(idx_tuple) for _ in elem_inputs]
        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
            zip(
                idx_tuple,
                elem_bcast,
                *(inp.type.broadcastable for inp in elem_inputs),
                # Indices can be shorter than input ndims
                strict=False,
            )
        ):
            if is_full_slice(dim_idx):
                # Full slice can be safely applied to all inputs
                continue

            # BUG FIX: compare each input's flag against the output's flag for
            # *this* dimension (`dim_bcast_out`), not against the whole
            # `elem_bcast` tuple — `bool == tuple` is always False, which made
            # this early `continue` unreachable.
            if all(dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs):
                # This dim is not broadcasted for any of the inputs; the
                # original index can be applied to all of them.
                continue

            # Some inputs broadcast this dim (length 1), so their index must
            # be adapted:
            #  - slice indexing keeps the dimension -> use a full slice
            #  - integer indexing drops the dimension -> index by zero
            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
                if dim_bcast_inp:
                    inp_idx[dim] = safe_bcast_dim_idx

        indexed_inputs = [
            inp[tuple(new_idx)]
            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
        ]

    [old_out] = node.outputs

    # Copy stack trace to new inputs (plain loop instead of a throwaway
    # list comprehension used only for its side effects)
    for new_inp in indexed_inputs:
        copy_stack_trace(old_out, new_inp)

    # Define elemwise operation on indexed inputs
    new_out = elem.owner.op(*indexed_inputs)

    # Copy stack trace to new output
    copy_stack_trace([old_out, *node.inputs], new_out)

    return [new_out]
207178
208179
209180@register_canonicalize ("shape_unsafe" )
@@ -328,6 +299,51 @@ def local_subtensor_of_transpose(fgraph, node):
328299 return [new_out ]
329300
330301
@register_canonicalize("fast_compile")
@node_rewriter([Subtensor])
def local_subtensor_of_unbroadcast(fgraph, node):
    """
    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
    """
    unbroadcast_var = node.inputs[0]
    if unbroadcast_var.owner is None or len(fgraph.clients[unbroadcast_var]) > 1:
        return False

    if not isinstance(unbroadcast_var.owner.op, Unbroadcast):
        return None

    # The Subtensor may drop dimensions, so the unbroadcast axes must be
    # remapped to the surviving dimensions.
    axes_before = unbroadcast_var.owner.op.axes
    axes_after = []

    # `pre` walks the dimensions before the Subtensor,
    # `post` counts the dimensions that survive it.
    post = 0
    for pre, index_entry in enumerate(node.op.idx_list):
        # Integer indexing drops the dimension; only a slice keeps it,
        # and hence only a slice can carry an unbroadcast axis through.
        if isinstance(index_entry, slice):
            if pre in axes_before:
                axes_after.append(post)
            post += 1
    # Dimensions beyond the indexed ones are untouched: keep their axes.
    for pre in range(len(node.op.idx_list), len(unbroadcast_var.broadcastable)):
        if pre in axes_before:
            axes_after.append(post)
        post += 1

    indexed = node.op(unbroadcast_var.owner.inputs[0], *node.inputs[1:])
    # Copy over previous output stacktrace
    copy_stack_trace(node.outputs[0], indexed)

    lifted = unbroadcast(indexed, *axes_after)
    # Copy over previous output stacktrace
    # and stacktrace from previous unary operation
    copy_stack_trace([node.outputs[0], node.inputs[0]], lifted)

    return [lifted]


331347@register_infer_shape
332348@register_useless
333349@register_canonicalize
0 commit comments