@@ -108,108 +108,79 @@ def local_subtensor_of_dot(fgraph, node):
108108 return [r ]
109109
110110
111- # fast_compile to allow opt subtensor(cast{float32}(make_vector) )
112- @register_canonicalize ( "fast_compile " )
@register_canonicalize("shape_unsafe")
@register_specialize("shape_unsafe")
@node_rewriter([Subtensor])
def local_subtensor_of_elemwise(fgraph, node):
    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.

    exp(x)[:, 0] -> exp(x[:, 0])
    add(x, y)[0] -> add(x[0], y[0])
    add(x[None], y)[2] -> add(x, y[2])

    Parameters
    ----------
    fgraph
        Graph being rewritten; only its ``clients`` mapping is consulted.
    node
        The Subtensor node whose first input may be the output of an Elemwise.

    Returns
    -------
    list or None
        A single-element list holding the Elemwise applied to the indexed
        inputs, or ``None`` when the rewrite does not apply (the indexed
        variable is not an Elemwise output, or that output has other clients).
    """
    elem, *idx = node.inputs

    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
        return None

    if len(fgraph.clients[elem]) > 1:
        # Elemwise output is used beyond the Subtensor.
        # Get out to avoid repeated computations
        return None

    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

    elem_inputs = elem.owner.inputs
    elem_bcast = elem.type.broadcastable
    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
        # No need to worry about implicit broadcasting.
        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]

    else:
        # The original indices may not make sense on some of the broadcasted dimensions
        new_idxs = [list(idx_tuple) for _ in elem_inputs]
        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
            zip(
                idx_tuple,
                elem_bcast,
                *(inp.type.broadcastable for inp in elem_inputs),
                # Indices can be shorter than input ndims
                strict=False,
            )
        ):
            if is_full_slice(dim_idx):
                # Full slice can be safely applied to all inputs
                continue

            # Compare against the output flag of *this* dimension. Comparing a
            # per-dimension boolean with the whole ``elem_bcast`` tuple is
            # always False, which would make this shortcut unreachable and
            # overwrite the index even when every input (and hence the output)
            # is broadcastable along this dimension.
            if all(
                dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs
            ):
                # This dim is not broadcasted for any of the inputs,
                # original index can be applied to all inputs
                continue

            # Some dims are broadcasted, so we need to adapt their indices
            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
                if dim_bcast_inp:
                    inp_idx[dim] = safe_bcast_dim_idx

        indexed_inputs = [
            inp[tuple(new_idx)]
            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
        ]

    [old_out] = node.outputs

    # Copy stack trace to new inputs
    for new_inp in indexed_inputs:
        copy_stack_trace(old_out, new_inp)

    # Define elemwise operation on indexed inputs
    new_out = elem.owner.op(*indexed_inputs)

    # Copy stack trace to new output
    copy_stack_trace([old_out, *node.inputs], new_out)

    return [new_out]
213184
214185
215186@register_canonicalize ("shape_unsafe" )
@@ -334,6 +305,51 @@ def local_subtensor_of_transpose(fgraph, node):
334305 return [new_out ]
335306
336307
@register_canonicalize("fast_compile")
@node_rewriter([Subtensor])
def local_subtensor_of_unbroadcast(fgraph, node):
    """Lift a Subtensor through an Unbroadcast.

    Unbroadcast(x)[idx] => Unbroadcast(x[idx])

    Returns ``False`` when the indexed variable has no owner or has more than
    one client, ``None`` when its owner is not an Unbroadcast, and otherwise a
    single-element list with the replacement output.
    """
    indexed_var = node.inputs[0]
    if indexed_var.owner is None or len(fgraph.clients[indexed_var]) > 1:
        return False

    if not isinstance(indexed_var.owner.op, Unbroadcast):
        return None

    # Subtensor might reduce dim., so the Unbroadcast axes must be renumbered.
    unbroadcast_axes = indexed_var.owner.op.axes
    idx_entries = node.op.idx_list

    # Input dimensions that survive the indexing, in output order:
    # explicitly indexed dims are kept only when indexed by a slice
    # (anything else drops the dimension) ...
    surviving_dims = [
        dim for dim, entry in enumerate(idx_entries) if isinstance(entry, slice)
    ]
    # ... followed by every trailing dim not mentioned in the index list.
    surviving_dims.extend(range(len(idx_entries), len(indexed_var.broadcastable)))

    # Position in `surviving_dims` is the axis number after the Subtensor.
    new_axes = [
        out_dim
        for out_dim, in_dim in enumerate(surviving_dims)
        if in_dim in unbroadcast_axes
    ]

    inner_indexed = node.op(indexed_var.owner.inputs[0], *node.inputs[1:])
    # Copy over previous output stacktrace
    copy_stack_trace(node.outputs[0], inner_indexed)

    lifted_out = unbroadcast(inner_indexed, *new_axes)
    # Copy over previous output stacktrace
    # and stacktrace from previous unary operation
    copy_stack_trace([node.outputs[0], node.inputs[0]], lifted_out)

    return [lifted_out]
351+
352+
337353@register_infer_shape
338354@register_useless
339355@register_canonicalize
0 commit comments