@@ -110,108 +110,79 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]
 
 
-# fast_compile to allow opt subtensor(cast{float32}(make_vector))
-@register_canonicalize("fast_compile")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_lift(fgraph, node):
+def local_subtensor_of_elemwise(fgraph, node):
+    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+
+    exp(x)[:, 0] -> exp(x[:, 0])
+    add(x, y)[0] -> add(x[0], y[0])
+    add(x[None], y)[2] -> add(x, y[2])
     """
-    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
+    elem, *idx = node.inputs
 
-    Handles the following unary ops:
-    elemwise(x,...)[idx] -> elemwise(x[idx],...)
-      when x,... are broadcasted scalar or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+        return None
 
-    """
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner is None or len(fgraph.clients[u]) > 1:
-            return False
-
-        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
-            idx = node.inputs[1:]
-            x_idx = node.op(u.owner.inputs[0], *idx)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, x_idx)
-            ret = u.owner.op(x_idx)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-            return [ret]
+    if len(fgraph.clients[elem]) > 1:
+        # Elemwise output is used beyond the Subtensor.
+        # Get out to avoid repeated computations
+        return None
 
-        if isinstance(u.owner.op, Elemwise):
-            new_inputs = []
-            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
-                # There is no broadcastable in the inputs
-                idx = node.inputs[1:]
-                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
-                # There is no broadcastable in the inputs or it is scalar
-                idx = node.inputs[1:]
-                new_inputs = []
-                for i in u.owner.inputs:
-                    if sum(i.type.broadcastable) == 0:
-                        new_inputs.append(node.op(i, *idx))
-                    else:
-                        # If the subtensor remove some dims, we must
-                        # lower the number of dimensions of this scalar.
-                        if node.outputs[0].ndim == i.ndim:
-                            new_inputs.append(i)
-                        else:
-                            new_inputs.append(
-                                i.dimshuffle(["x"] * node.outputs[0].ndim)
-                            )
-
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-
-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dim., adapt broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes broadcastable pattern before subtensor
-            # j indexes broadcastable pattern after subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension, should
-                # not appear in the broascastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
+    elem_inputs = elem.owner.inputs
+    elem_bcast = elem.type.broadcastable
+    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+        # No need to worry about implicit broadcasting.
+        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
+
+    else:
+        # The original indices may not make sense on some of the broadcasted dimensions
+        new_idxs = [list(idx_tuple) for _ in elem_inputs]
+        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
+            zip(
+                idx_tuple,
+                elem_bcast,
+                *(inp.type.broadcastable for inp in elem_inputs),
+                # Indices can be shorter than input ndims
+                strict=False,
+            )
+        ):
+            if is_full_slice(dim_idx):
+                # Full slice can be safely applied to all inputs
+                continue
 
-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+            if all(dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs):
+                # This dim is not broadcasted for any of the inputs, so the original index can be applied to all inputs
+                continue
+
+            # Some dims are broadcasted, so we need to adapt their indices
+            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
+            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
+            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
+            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
+                if dim_bcast_inp:
+                    inp_idx[dim] = safe_bcast_dim_idx
 
-            return [rbcast_subt_x]
+        indexed_inputs = [
+            inp[tuple(new_idx)]
+            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
+        ]
+
+    [old_out] = node.outputs
+
+    # Copy stack trace to new inputs
+    [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
+
+    # Define elemwise operation on indexed inputs
+    new_out = elem.owner.op(*indexed_inputs)
+
+    # Copy stack trace to new output
+    copy_stack_trace([old_out, *node.inputs], new_out)
+
+    return [new_out]
 
 
 @register_canonicalize("shape_unsafe")
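
For readers skimming the diff: a minimal sketch of what the new `local_subtensor_of_elemwise` rewrite achieves, assuming a standard PyTensor setup in which the rewrite participates in the default pipeline through its `shape_unsafe` canonicalize/specialize registrations (the variable names below are illustrative):

```python
import pytensor
import pytensor.tensor as pt

x = pt.vector("x")
y = pt.matrix("y")

# Graph as written: broadcast x[None] (shape (1, n)) against y, then take row 2.
# Indexing the broadcast dimension of x[None] with 2 is only meaningful after
# broadcasting, so the rewrite substitutes the safe index 0 on that input:
#   add(x[None], y)[2] -> add(x[None][0], y[2]) == add(x, y[2])
out = pt.add(x[None], y)[2]

# With the rewrite applied, the compiled graph should index first and then add
# two vectors, never materializing the broadcasted (m, n) intermediate.
fn = pytensor.function([x, y], out)
```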
@@ -336,6 +307,51 @@ def local_subtensor_of_transpose(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize("fast_compile")
+@node_rewriter([Subtensor])
+def local_subtensor_of_unbroadcast(fgraph, node):
+    """
+    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    """
+    u = node.inputs[0]
+    if u.owner is None or len(fgraph.clients[u]) > 1:
+        return False
+
+    if isinstance(u.owner.op, Unbroadcast):
+        # Subtensor might reduce dim., adapt broadcast pattern accordingly
+        old_axes = u.owner.op.axes
+        new_axes = []
+
+        # loop through indices being subtensor-ed
+        # i indexes broadcastable pattern before subtensor
+        # j indexes broadcastable pattern after subtensor
+        j = 0
+        for i, x in enumerate(node.op.idx_list):
+            # if it is not a slice, it will reduce the dimension and should
+            # not appear in the broadcastable dimensions
+            if isinstance(x, slice):
+                if i in old_axes:
+                    new_axes.append(j)
+                j += 1
+        # now keep the broadcastable pattern of all
+        # items not appearing in subtensor list
+        for i in range(len(node.op.idx_list), len(u.broadcastable)):
+            if i in old_axes:
+                new_axes.append(j)
+            j += 1
+
+        subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs[0], subt_x)
+
+        rbcast_subt_x = unbroadcast(subt_x, *new_axes)
+        # Copy over previous output stacktrace
+        # and stacktrace from previous unary operation
+        copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+
+        return [rbcast_subt_x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
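
As a standalone illustration of the axis bookkeeping in `local_subtensor_of_unbroadcast` above, here is a plain-Python sketch; `remap_unbroadcast_axes` is a hypothetical helper written for this note, not part of the codebase:

```python
def remap_unbroadcast_axes(idx_list, ndim, old_axes):
    # Hypothetical restatement of the two loops above: an integer index drops
    # its dimension, shifting every surviving Unbroadcast axis to its right one
    # position left; a slice keeps its dimension, and dimensions beyond the
    # index list are kept unchanged.
    new_axes = []
    j = 0  # position among the post-Subtensor dimensions
    for i in range(ndim):
        if i < len(idx_list) and not isinstance(idx_list[i], slice):
            continue  # integer index: this dimension is dropped
        if i in old_axes:
            new_axes.append(j)
        j += 1
    return new_axes

# x[0, :] on a 3d tensor with Unbroadcast axes (0, 2):
# dim 0 is dropped and dim 2 survives as output dim 1.
assert remap_unbroadcast_axes([0, slice(None)], 3, (0, 2)) == [1]
```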