
Commit e1ee3a2

Generalize lift of Subtensor over Elemwise
Split off Subtensor of Unbroadcast into its own rewrite
1 parent 5c97e9f commit e1ee3a2
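
The commit generalizes the old local_subtensor_lift: a Subtensor is now lifted through any Elemwise, including multi-input ones that rely on implicit broadcasting, instead of only the unary or no-broadcast cases. Below is a minimal sketch of how the new rewrite could be exercised; it assumes pytensor's rewrite_graph helper and the canonicalize pass pick up the registered rewrite, and the variable names are illustrative only, not part of the commit.

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")  # shape (n,)
y = pt.matrix("y")  # shape (m, n)

# add(x[None], y)[2] should become add(x, y[2]) after the lift
out = (x[None] + y)[2]
rewritten = rewrite_graph(out, include=("canonicalize",))
pytensor.dprint(rewritten)  # the Subtensor is expected to land on y, with x used directly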

File tree

2 files changed, +256 -264 lines

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 110 additions & 94 deletions
@@ -110,108 +110,79 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]
 
 
-# fast_compile to allow opt subtensor(cast{float32}(make_vector))
-@register_canonicalize("fast_compile")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_lift(fgraph, node):
+def local_subtensor_of_elemwise(fgraph, node):
+    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+
+    exp(x)[:, 0] -> exp(x[:, 0])
+    add(x, y)[0] -> add(x[0], y[0])
+    add(x[None], y)[2] -> add(x, y[2])
     """
-    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
+    elem, *idx = node.inputs
 
-    Handles the following unary ops:
-    elemwise(x,...)[idx] -> elemwise(x[idx],...)
-      when x,... are broadcasted scalar or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+        return None
 
-    """
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner is None or len(fgraph.clients[u]) > 1:
-            return False
-
-        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
-            idx = node.inputs[1:]
-            x_idx = node.op(u.owner.inputs[0], *idx)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, x_idx)
-            ret = u.owner.op(x_idx)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-            return [ret]
+    if len(fgraph.clients[elem]) > 1:
+        # Elemwise output is used beyond the Subtensor.
+        # Get out to avoid repeated computations
+        return None
 
-        if isinstance(u.owner.op, Elemwise):
-            new_inputs = []
-            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
-                # There is no broadcastable in the inputs
-                idx = node.inputs[1:]
-                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
-                # There is no broadcastable in the inputs or it is scalar
-                idx = node.inputs[1:]
-                new_inputs = []
-                for i in u.owner.inputs:
-                    if sum(i.type.broadcastable) == 0:
-                        new_inputs.append(node.op(i, *idx))
-                    else:
-                        # If the subtensor remove some dims, we must
-                        # lower the number of dimensions of this scalar.
-                        if node.outputs[0].ndim == i.ndim:
-                            new_inputs.append(i)
-                        else:
-                            new_inputs.append(
-                                i.dimshuffle(["x"] * node.outputs[0].ndim)
-                            )
-
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-
-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dim., adapt broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes broadcastable pattern before subtensor
-            # j indexes broadcastable pattern after subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension, should
-                # not appear in the broascastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
+    elem_inputs = elem.owner.inputs
+    elem_bcast = elem.type.broadcastable
+    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+        # No need to worry about implicit broadcasting.
+        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
+
+    else:
+        # The original indices may not make sense on some of the broadcasted dimensions
+        new_idxs = [list(idx_tuple) for _ in elem_inputs]
+        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
+            zip(
+                idx_tuple,
+                elem_bcast,
+                *(inp.type.broadcastable for inp in elem_inputs),
+                # Indices can be shorter than input ndims
+                strict=False,
+            )
+        ):
+            if is_full_slice(dim_idx):
+                # Full slice can be safely applied to all inputs
+                continue
 
-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+            if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
+                # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
+                continue
+
+            # Some dims are broadcasted, so we need to adapt their indices
+            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
+            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
+            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
+            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
+                if dim_bcast_inp:
+                    inp_idx[dim] = safe_bcast_dim_idx
 
-            return [rbcast_subt_x]
+        indexed_inputs = [
+            inp[tuple(new_idx)]
+            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
+        ]
+
+    [old_out] = node.outputs
+
+    # Copy stack trace to new inputs
+    [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
+
+    # Define elemwise operation on indexed inputs
+    new_out = elem.owner.op(*indexed_inputs)
+
+    # Copy stack trace to new output
+    copy_stack_trace([old_out, *node.inputs], new_out)
+
+    return [new_out]
 
 
 @register_canonicalize("shape_unsafe")
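
The per-input index adjustment in the hunk above relies on a simple fact about broadcasting: on a dimension where an input has length 1 (broadcasted), an integer index can be replaced by 0 (both drop the dimension) and a non-full slice by a full slice (both keep it), without changing the elementwise result. A quick NumPy check of that identity, independent of the rewrite itself (array names and sizes are made up for illustration):

import numpy as np

x = np.random.default_rng(0).normal(size=3)        # conceptually the broadcasted input x[None], shape (1, 3)
y = np.random.default_rng(1).normal(size=(5, 3))

# Original graph: add(x[None], y)[2]
original = (x[None] + y)[2]

# Lifted graph: the integer index 2 hits a broadcasted dim of x[None],
# so it is replaced by 0 (equivalently, x itself); y keeps the original index.
lifted = x[None][0] + y[2]        # == x + y[2]
np.testing.assert_allclose(original, lifted)

# Slice indices keep the dim, so broadcasted inputs get a full slice instead:
original_slice = (x[None] + y)[1:4]
lifted_slice = x[None][:] + y[1:4]
np.testing.assert_allclose(original_slice, lifted_slice)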
@@ -336,6 +307,51 @@ def local_subtensor_of_transpose(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize("fast_compile")
+@node_rewriter([Subtensor])
+def local_subtensor_of_unbroadcast(fgraph, node):
+    """
+    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    """
+    u = node.inputs[0]
+    if u.owner is None or len(fgraph.clients[u]) > 1:
+        return False
+
+    if isinstance(u.owner.op, Unbroadcast):
+        # Subtensor might reduce dim., adapt broadcast pattern accordingly
+        old_axes = u.owner.op.axes
+        new_axes = []
+
+        # loop through indices being subtensor-ed
+        # i indexes broadcastable pattern before subtensor
+        # j indexes broadcastable pattern after subtensor
+        j = 0
+        for i, x in enumerate(node.op.idx_list):
+            # if it is not a slice, it will reduce the dimension, should
+            # not appear in the broadcastable dimensions
+            if isinstance(x, slice):
+                if i in old_axes:
+                    new_axes.append(j)
+                j += 1
+        # now keep the broadcastable pattern of all
+        # items not appearing in subtensor list
+        for i in range(len(node.op.idx_list), len(u.broadcastable)):
+            if i in old_axes:
+                new_axes.append(j)
+            j += 1
+
+        subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs[0], subt_x)
+
+        rbcast_subt_x = unbroadcast(subt_x, *new_axes)
+        # Copy over previous output stacktrace
+        # and stacktrace from previous unary operation
+        copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+
+        return [rbcast_subt_x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
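
The axis bookkeeping in local_subtensor_of_unbroadcast can be read in isolation: integer indices drop dimensions, so an Unbroadcast axis survives the Subtensor only if its dimension is kept by a slice or not indexed at all, and its position shifts left past any dropped dimensions. A standalone sketch of that remapping follows; the plain-int entry stands in for the scalar index types found in idx_list, and the helper name is made up for illustration:

def remap_unbroadcast_axes(idx_list, old_axes, ndim):
    """Standalone version of the axis bookkeeping in local_subtensor_of_unbroadcast.

    idx_list: per-dimension index entries; slices keep a dim, integers drop it.
    old_axes: axes the Unbroadcast op applied to before the Subtensor.
    ndim: number of dimensions of the Unbroadcast output.
    """
    new_axes = []
    j = 0  # position in the post-Subtensor broadcastable pattern
    for i, entry in enumerate(idx_list):
        if isinstance(entry, slice):  # sliced dims survive
            if i in old_axes:
                new_axes.append(j)
            j += 1
        # integer entries drop the dim: no j increment, the axis disappears
    for i in range(len(idx_list), ndim):  # trailing dims not indexed at all
        if i in old_axes:
            new_axes.append(j)
        j += 1
    return new_axes


# x has 3 dims; Unbroadcast was applied to axes (0, 2); the Subtensor drops dim 0.
print(remap_unbroadcast_axes([0, slice(None)], old_axes=(0, 2), ndim=3))  # -> [1]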
