
Commit 7509359

Generalize lift of Subtensor over Elemwise
Split off Subtensor of Unbroadcast into its own rewrite
1 parent 605e733 commit 7509359
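
The identities this rewrite relies on can be sanity-checked numerically. The following NumPy sketch (illustrative only, not part of the commit) mirrors the examples from the new docstring:

    import numpy as np

    x = np.random.normal(size=(3, 4))
    y = np.random.normal(size=(3, 4))
    y3 = np.random.normal(size=(5, 3, 4))

    # exp(x)[:, 0] -> exp(x[:, 0])
    assert np.allclose(np.exp(x)[:, 0], np.exp(x[:, 0]))

    # add(x, y)[0] -> add(x[0], y[0])
    assert np.allclose((x + y)[0], x[0] + y[0])

    # add(x[None], y3)[2] -> add(x, y3[2]): the integer index on the
    # broadcasted (length-1) leading dim of x[None] is dropped entirely
    assert np.allclose((x[None] + y3)[2], x + y3[2])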

File tree

2 files changed: +256 -264 lines changed


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 110 additions & 94 deletions
@@ -110,108 +110,79 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]


-# fast_compile to allow opt subtensor(cast{float32}(make_vector))
-@register_canonicalize("fast_compile")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_lift(fgraph, node):
+def local_subtensor_of_elemwise(fgraph, node):
+    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+
+    exp(x)[:, 0] -> exp(x[:, 0])
+    add(x, y)[0] -> add(x[0], y[0])
+    add(x[None], y)[2] -> add(x, y[2])
     """
-    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
+    elem, *idx = node.inputs

-    Handles the following unary ops:
-    elemwise(x,...)[idx] -> elemwise(x[idx],...)
-      when x,... are broadcasted scalar or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+        return None

-    """
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner is None or len(fgraph.clients[u]) > 1:
-            return False
-
-        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
-            idx = node.inputs[1:]
-            x_idx = node.op(u.owner.inputs[0], *idx)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, x_idx)
-            ret = u.owner.op(x_idx)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-            return [ret]
+    if len(fgraph.clients[elem]) > 1:
+        # Elemwise output is used beyond the Subtensor.
+        # Get out to avoid repeated computations
+        return None

-        if isinstance(u.owner.op, Elemwise):
-            new_inputs = []
-            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
-                # There is no broadcastable in the inputs
-                idx = node.inputs[1:]
-                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
-                # There is no broadcastable in the inputs or it is scalar
-                idx = node.inputs[1:]
-                new_inputs = []
-                for i in u.owner.inputs:
-                    if sum(i.type.broadcastable) == 0:
-                        new_inputs.append(node.op(i, *idx))
-                    else:
-                        # If the subtensor remove some dims, we must
-                        # lower the number of dimensions of this scalar.
-                        if node.outputs[0].ndim == i.ndim:
-                            new_inputs.append(i)
-                        else:
-                            new_inputs.append(
-                                i.dimshuffle(["x"] * node.outputs[0].ndim)
-                            )
-
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-
-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dim., adapt broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes broadcastable pattern before subtensor
-            # j indexes broadcastable pattern after subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension, should
-                # not appear in the broascastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
+    elem_inputs = elem.owner.inputs
+    elem_bcast = elem.type.broadcastable
+    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+        # No need to worry about implicit broadcasting.
+        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
+
+    else:
+        # The original indices may not make sense on some of the broadcasted dimensions
+        new_idxs = [list(idx_tuple) for _ in elem_inputs]
+        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
+            zip(
+                idx_tuple,
+                elem_bcast,
+                *(inp.type.broadcastable for inp in elem_inputs),
+                # Indices can be shorter than input ndims
+                strict=False,
+            )
+        ):
+            if is_full_slice(dim_idx):
+                # Full slice can be safely applied to all inputs
+                continue

-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+            if all(dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs):
+                # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
+                continue
+
+            # Some dims are broadcasted, so we need to adapt their indices
+            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
+            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
+            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
+            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
+                if dim_bcast_inp:
+                    inp_idx[dim] = safe_bcast_dim_idx

-            return [rbcast_subt_x]
+        indexed_inputs = [
+            inp[tuple(new_idx)]
+            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
+        ]
+
+    [old_out] = node.outputs
+
+    # Copy stack trace to new inputs
+    [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
+
+    # Define elemwise operation on indexed inputs
+    new_out = elem.owner.op(*indexed_inputs)
+
+    # Copy stack trace to new output
+    copy_stack_trace([old_out, *node.inputs], new_out)
+
+    return [new_out]


 @register_canonicalize("shape_unsafe")
@@ -337,6 +308,51 @@ def local_subtensor_of_transpose(fgraph, node):
     return [new_out]


+@register_canonicalize("fast_compile")
+@node_rewriter([Subtensor])
+def local_subtensor_of_unbroadcast(fgraph, node):
+    """
+    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    """
+    u = node.inputs[0]
+    if u.owner is None or len(fgraph.clients[u]) > 1:
+        return False
+
+    if isinstance(u.owner.op, Unbroadcast):
+        # Subtensor might reduce dim., adapt broadcast pattern accordingly
+        old_axes = u.owner.op.axes
+        new_axes = []
+
+        # loop through indices being subtensor-ed
+        # i indexes broadcastable pattern before subtensor
+        # j indexes broadcastable pattern after subtensor
+        j = 0
+        for i, x in enumerate(node.op.idx_list):
+            # if it is not a slice, it will reduce the dimension, should
+            # not appear in the broadcastable dimensions
+            if isinstance(x, slice):
+                if i in old_axes:
+                    new_axes.append(j)
+                j += 1
+        # now keep the broadcastable pattern of all
+        # items not appearing in subtensor list
+        for i in range(len(node.op.idx_list), len(u.broadcastable)):
+            if i in old_axes:
+                new_axes.append(j)
+            j += 1
+
+        subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs[0], subt_x)
+
+        rbcast_subt_x = unbroadcast(subt_x, *new_axes)
+        # Copy over previous output stacktrace
+        # and stacktrace from previous unary operation
+        copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+
+        return [rbcast_subt_x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
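
The loops in `local_subtensor_of_unbroadcast` renumber the Unbroadcast axes that survive the indexing, since non-slice indices drop dimensions. A hypothetical standalone helper (not part of the commit) that mirrors this bookkeeping:

    def remap_unbroadcast_axes(old_axes, idx_list, ndim):
        # Renumber axes after a Subtensor: slices keep a dimension,
        # non-slice (integer) indices drop it; trailing, unindexed
        # dimensions are kept as-is.
        new_axes = []
        j = 0  # position in the indexed (output) tensor
        for i in range(ndim):
            if i < len(idx_list) and not isinstance(idx_list[i], slice):
                continue  # dimension dropped by an integer index
            if i in old_axes:
                new_axes.append(j)
            j += 1
        return new_axes

    # x has 3 dims and Unbroadcast is applied on axes (0, 2); indexing
    # x[0:2, 1] drops dim 1, so old axis 2 becomes new axis 1.
    assert remap_unbroadcast_axes((0, 2), [slice(0, 2), 1], ndim=3) == [0, 1]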
