Commit 468a83e

Generalize lift of Subtensor over Elemwise
Split off Subtensor of Unbroadcast into its own rewrite
1 parent: f8cfe6a
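
In effect, Subtensor indexing can now be lifted through any Elemwise, including multi-input ones with implicitly broadcast arguments. A minimal sketch of the intended behavior, assuming PyTensor's public API (the rewrite_graph call is illustrative, and the exact rewritten graph may differ):

import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.matrix("x")
y = pt.row("y")  # static shape (1, n): broadcast along the first axis

# The Subtensor sits on top of the whole Elemwise graph...
out = pt.exp(x + y)[2]

# ...and rewriting lifts it onto the inputs, roughly exp(x[2] + y[0]),
# so only the required row is ever computed.
rewritten = rewrite_graph(out)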

File tree

2 files changed: +256 -264 lines


pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 110 additions & 94 deletions
@@ -102,108 +102,79 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]


-# fast_compile to allow opt subtensor(cast{float32}(make_vector))
-@register_canonicalize("fast_compile")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_lift(fgraph, node):
+def local_subtensor_of_elemwise(fgraph, node):
+    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+
+    exp(x)[:, 0] -> exp(x[:, 0])
+    add(x, y)[0] -> add(x[0], y[0])
+    add(x[None], y)[2] -> add(x, y[2])
     """
-    unary(x)[idx] -> unary(x[idx])  # any broadcast pattern
+    elem, *idx = node.inputs

-    Handles the following unary ops:
-    elemwise(x, ...)[idx] -> elemwise(x[idx], ...)
-      when x, ... are broadcasted scalars or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+        return None

-    """
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner is None or len(fgraph.clients[u]) > 1:
-            return False
-
-        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
-            idx = node.inputs[1:]
-            x_idx = node.op(u.owner.inputs[0], *idx)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, x_idx)
-            ret = u.owner.op(x_idx)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-            return [ret]
+    if len(fgraph.clients[elem]) > 1:
+        # Elemwise output is used beyond the Subtensor.
+        # Bail out to avoid repeated computations.
+        return None

-        if isinstance(u.owner.op, Elemwise):
-            new_inputs = []
-            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
-                # There are no broadcastable dimensions in the inputs
-                idx = node.inputs[1:]
-                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
-                # Each input has no broadcastable dims or is fully broadcastable (scalar-like)
-                idx = node.inputs[1:]
-                new_inputs = []
-                for i in u.owner.inputs:
-                    if sum(i.type.broadcastable) == 0:
-                        new_inputs.append(node.op(i, *idx))
-                    else:
-                        # If the subtensor removes some dims, we must
-                        # lower the number of dimensions of this scalar.
-                        if node.outputs[0].ndim == i.ndim:
-                            new_inputs.append(i)
-                        else:
-                            new_inputs.append(
-                                i.dimshuffle(["x"] * node.outputs[0].ndim)
-                            )
-
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-
-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dims; adapt the broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes the broadcastable pattern before the subtensor
-            # j indexes the broadcastable pattern after the subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension and should
-                # not appear in the broadcastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in the subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)

-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
+    elem_inputs = elem.owner.inputs
+    elem_bcast = elem.type.broadcastable
+    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+        # No need to worry about implicit broadcasting.
+        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
+
+    else:
+        # The original indices may not make sense on some of the broadcasted dimensions
+        new_idxs = [list(idx_tuple) for _ in elem_inputs]
+        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
+            zip(
+                idx_tuple,
+                elem_bcast,
+                *(inp.type.broadcastable for inp in elem_inputs),
+                # Indices can be shorter than input ndims
+                strict=False,
+            )
+        ):
+            if is_full_slice(dim_idx):
+                # A full slice can be safely applied to all inputs
+                continue

-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+            if all(dim_bcast_inp == dim_bcast_out for dim_bcast_inp in dim_bcast_inputs):
+                # This dim is not broadcasted for any of the inputs; the original index applies to all
+                continue
+
+            # Some dims are broadcasted, so we need to adapt their indices.
+            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs.
+            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs.
+            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
+            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
+                if dim_bcast_inp:
+                    inp_idx[dim] = safe_bcast_dim_idx

-            return [rbcast_subt_x]
+        indexed_inputs = [
+            inp[tuple(new_idx)]
+            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
+        ]
+
+    [old_out] = node.outputs
+
+    # Copy stack trace to new inputs
+    [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
+
+    # Define the elemwise operation on the indexed inputs
+    new_out = elem.owner.op(*indexed_inputs)
+
+    # Copy stack trace to new output
+    copy_stack_trace([old_out, *node.inputs], new_out)
+
+    return [new_out]


 @register_canonicalize("shape_unsafe")
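
The index adaptation in the else branch above mirrors how indexing interacts with NumPy-style broadcasting. A self-contained NumPy analogue (not part of the commit) of the add(x[None], y)[2] -> add(x, y[2]) case from the docstring:

import numpy as np

x = np.random.rand(1, 5)  # dim 0 has length 1 and is broadcast against y
y = np.random.rand(3, 5)

out = (x + y)[2]  # integer index on the broadcasted dim

# Lifting the index unchanged into x would be out of bounds (x has one row).
# On a broadcasted input an integer index is replaced by 0 (the dim is
# dropped either way), while a slice would be replaced by a full slice
# (the dim is kept and still has length 1).
lifted = x[0] + y[2]

assert np.allclose(out, lifted)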
@@ -328,6 +299,51 @@ def local_subtensor_of_transpose(fgraph, node):
     return [new_out]


+@register_canonicalize("fast_compile")
+@node_rewriter([Subtensor])
+def local_subtensor_of_unbroadcast(fgraph, node):
+    """
+    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    """
+    u = node.inputs[0]
+    if u.owner is None or len(fgraph.clients[u]) > 1:
+        return False
+
+    if isinstance(u.owner.op, Unbroadcast):
+        # Subtensor might reduce dims; adapt the broadcast pattern accordingly
+        old_axes = u.owner.op.axes
+        new_axes = []
+
+        # loop through indices being subtensor-ed
+        # i indexes the broadcastable pattern before the subtensor
+        # j indexes the broadcastable pattern after the subtensor
+        j = 0
+        for i, x in enumerate(node.op.idx_list):
+            # if it is not a slice, it will reduce the dimension and should
+            # not appear in the broadcastable dimensions
+            if isinstance(x, slice):
+                if i in old_axes:
+                    new_axes.append(j)
+                j += 1
+        # now keep the broadcastable pattern of all
+        # items not appearing in the subtensor list
+        for i in range(len(node.op.idx_list), len(u.broadcastable)):
+            if i in old_axes:
+                new_axes.append(j)
+            j += 1
+
+        subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs[0], subt_x)
+
+        rbcast_subt_x = unbroadcast(subt_x, *new_axes)
+        # Copy over previous output stacktrace
+        # and stacktrace from previous unary operation
+        copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+
+        return [rbcast_subt_x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
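
The axis renumbering in this rewrite can be traced in isolation. A plain-Python sketch with hypothetical values (no PyTensor involved): a 4d input was Unbroadcast on axes 0, 2 and 3, and is then indexed as x[:, 0, :], dropping dimension 1:

idx_list = [slice(None), 0, slice(None)]  # the index pattern of x[:, 0, :]
old_axes = (0, 2, 3)  # axes unbroadcast before the Subtensor
ndim = 4

new_axes = []
j = 0
for i, entry in enumerate(idx_list):
    # Integer indices drop their dimension, so only slices advance
    # the output-axis counter j.
    if isinstance(entry, slice):
        if i in old_axes:
            new_axes.append(j)
        j += 1
for i in range(len(idx_list), ndim):  # trailing, un-indexed dims survive
    if i in old_axes:
        new_axes.append(j)
    j += 1

print(new_axes)  # [0, 1, 2]: axes 0, 2 and 3 shift down once dim 1 is dropped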
