
Commit bc9baac

Generalize lift of Subtensor over Elemwise

Split off Subtensor of Unbroadcast into its own rewrite

1 parent f5fd84e

2 files changed: +256, -264 lines

pytensor/tensor/rewriting/subtensor_lift.py

Lines changed: 110 additions & 94 deletions
@@ -108,108 +108,79 @@ def local_subtensor_of_dot(fgraph, node):
     return [r]
 
 
-# fast_compile to allow opt subtensor(cast{float32}(make_vector))
-@register_canonicalize("fast_compile")
+@register_canonicalize("shape_unsafe")
+@register_specialize("shape_unsafe")
 @node_rewriter([Subtensor])
-def local_subtensor_lift(fgraph, node):
+def local_subtensor_of_elemwise(fgraph, node):
+    """Lift a Subtensor through an Elemwise and its implicit broadcasting behavior.
+
+    exp(x)[:, 0] -> exp(x[:, 0])
+    add(x, y)[0] -> add(x[0], y[0])
+    add(x[None], y)[2] -> add(x, y[2])
     """
-    unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
+    elem, *idx = node.inputs
 
-    Handles the following unary ops:
-    elemwise(x,...)[idx] -> elemwise(x[idx],...)
-      when x,... are broadcasted scalar or not broadcasted at all
-    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    if not (elem.owner and isinstance(elem.owner.op, Elemwise)):
+        return None
 
-    """
-    if isinstance(node.op, Subtensor):
-        u = node.inputs[0]
-        if u.owner is None or len(fgraph.clients[u]) > 1:
-            return False
-
-        if isinstance(u.owner.op, Elemwise) and len(u.owner.inputs) == 1:
-            idx = node.inputs[1:]
-            x_idx = node.op(u.owner.inputs[0], *idx)
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs, x_idx)
-            ret = u.owner.op(x_idx)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-            return [ret]
+    if len(fgraph.clients[elem]) > 1:
+        # Elemwise output is used beyond the Subtensor.
+        # Get out to avoid repeated computations
+        return None
 
-        if isinstance(u.owner.op, Elemwise):
-            new_inputs = []
-            if all(sum(i.type.broadcastable) == 0 for i in u.owner.inputs):
-                # There is no broadcastable in the inputs
-                idx = node.inputs[1:]
-                new_inputs = [node.op(i, *idx) for i in u.owner.inputs]
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-            elif all(sum(i.type.broadcastable) in [i.ndim, 0] for i in u.owner.inputs):
-                # There is no broadcastable in the inputs or it is scalar
-                idx = node.inputs[1:]
-                new_inputs = []
-                for i in u.owner.inputs:
-                    if sum(i.type.broadcastable) == 0:
-                        new_inputs.append(node.op(i, *idx))
-                    else:
-                        # If the subtensor remove some dims, we must
-                        # lower the number of dimensions of this scalar.
-                        if node.outputs[0].ndim == i.ndim:
-                            new_inputs.append(i)
-                        else:
-                            new_inputs.append(
-                                i.dimshuffle(["x"] * node.outputs[0].ndim)
-                            )
-
-                # Copy over previous output stacktrace
-                copy_stack_trace(node.outputs[0], new_inputs)
-
-                ret = u.owner.op(*new_inputs)
-                # Copy over previous output stacktrace
-                # and stacktrace from previous unary operation
-                copy_stack_trace([node.outputs[0], node.inputs[0]], ret)
-                return [ret]
-
-        if isinstance(u.owner.op, Unbroadcast):
-            # Subtensor might reduce dim., adapt broadcast pattern accordingly
-            old_axes = u.owner.op.axes
-            new_axes = []
-
-            # loop through indices being subtensor-ed
-            # i indexes broadcastable pattern before subtensor
-            # j indexes broadcastable pattern after subtensor
-            j = 0
-            for i, x in enumerate(node.op.idx_list):
-                # if it is not a slice, it will reduce the dimension, should
-                # not appear in the broascastable dimensions
-                if isinstance(x, slice):
-                    if i in old_axes:
-                        new_axes.append(j)
-                    j += 1
-            # now keep the broadcastable pattern of all
-            # items not appearing in subtensor list
-            for i in range(len(node.op.idx_list), len(u.broadcastable)):
-                if i in old_axes:
-                    new_axes.append(j)
-                j += 1
+    idx_tuple = indices_from_subtensor(idx, node.op.idx_list)
 
-            subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
-            # Copy over previous output stacktrace
-            copy_stack_trace(node.outputs[0], subt_x)
+    elem_inputs = elem.owner.inputs
+    elem_bcast = elem.type.broadcastable
+    if all(inp.type.broadcastable == elem_bcast for inp in elem_inputs):
+        # No need to worry about implicit broadcasting.
+        indexed_inputs = [inp[idx_tuple] for inp in elem_inputs]
+
+    else:
+        # The original indices may not make sense on some of the broadcasted dimensions
+        new_idxs = [list(idx_tuple) for _ in elem_inputs]
+        for dim, (dim_idx, dim_bcast_out, *dim_bcast_inputs) in enumerate(
+            zip(
+                idx_tuple,
+                elem_bcast,
+                *(inp.type.broadcastable for inp in elem_inputs),
+                # Indices can be shorter than input ndims
+                strict=False,
+            )
+        ):
+            if is_full_slice(dim_idx):
+                # Full slice can be safely applied to all inputs
+                continue
 
-            rbcast_subt_x = unbroadcast(subt_x, *new_axes)
-            # Copy over previous output stacktrace
-            # and stacktrace from previous unary operation
-            copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+            if all(dim_bcast_inp == elem_bcast for dim_bcast_inp in dim_bcast_inputs):
+                # This dim is not broadcasted for any of the inputs, original index can be applied to all inputs
+                continue
+
+            # Some dims are broadcasted, so we need to adapt their indices
+            # Slice indexing keeps the dimension, so we use a full slice for broadcasted inputs
+            # Integer indexing drops the dimension, so we index by zero for the broadcasted inputs
+            safe_bcast_dim_idx = slice(None) if isinstance(dim_idx, slice) else 0
+            for inp_idx, dim_bcast_inp in zip(new_idxs, dim_bcast_inputs, strict=True):
+                if dim_bcast_inp:
+                    inp_idx[dim] = safe_bcast_dim_idx
 
-            return [rbcast_subt_x]
+        indexed_inputs = [
+            inp[tuple(new_idx)]
+            for inp, new_idx in zip(elem_inputs, new_idxs, strict=True)
+        ]
+
+    [old_out] = node.outputs
+
+    # Copy stack trace to new inputs
+    [copy_stack_trace(old_out, new_inp) for new_inp in indexed_inputs]
+
+    # Define elemwise operation on indexed inputs
+    new_out = elem.owner.op(*indexed_inputs)
+
+    # Copy stack trace to new output
+    copy_stack_trace([old_out, *node.inputs], new_out)
+
+    return [new_out]
 
 
 @register_canonicalize("shape_unsafe")
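
As an illustration of the generalized rewrite (not part of the commit), the sketch below builds the add(x[None], y)[2] case from the new docstring and runs the canonicalization rewrites on it; rewrite_graph and dprint are assumed to be the usual pytensor helpers, and the exact printed graph may differ by version. The Subtensor is expected to end up on the inputs, roughly add(x, y[2]), with the index dropped on the broadcasted input.

import pytensor
import pytensor.tensor as pt
from pytensor.graph.rewriting.utils import rewrite_graph

x = pt.vector("x")
y = pt.matrix("y")

# x[None] is implicitly broadcast along the first axis of y;
# the integer index then selects a single row of the sum.
out = (x[None, :] + y)[2]

# Apply the canonicalization rewrites; the Subtensor should be lifted
# through the Elemwise, indexing y directly and leaving x untouched.
rewritten = rewrite_graph(out, include=("canonicalize",))
pytensor.dprint(rewritten)
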
@@ -334,6 +305,51 @@ def local_subtensor_of_transpose(fgraph, node):
     return [new_out]
 
 
+@register_canonicalize("fast_compile")
+@node_rewriter([Subtensor])
+def local_subtensor_of_unbroadcast(fgraph, node):
+    """
+    Unbroadcast(x)[idx] => Unbroadcast(x[idx])
+    """
+    u = node.inputs[0]
+    if u.owner is None or len(fgraph.clients[u]) > 1:
+        return False
+
+    if isinstance(u.owner.op, Unbroadcast):
+        # Subtensor might reduce dim., adapt broadcast pattern accordingly
+        old_axes = u.owner.op.axes
+        new_axes = []
+
+        # loop through indices being subtensor-ed
+        # i indexes broadcastable pattern before subtensor
+        # j indexes broadcastable pattern after subtensor
+        j = 0
+        for i, x in enumerate(node.op.idx_list):
+            # if it is not a slice, it will reduce the dimension, should
+            # not appear in the broadcastable dimensions
+            if isinstance(x, slice):
+                if i in old_axes:
+                    new_axes.append(j)
+                j += 1
+        # now keep the broadcastable pattern of all
+        # items not appearing in subtensor list
+        for i in range(len(node.op.idx_list), len(u.broadcastable)):
+            if i in old_axes:
+                new_axes.append(j)
+            j += 1
+
+        subt_x = node.op(u.owner.inputs[0], *node.inputs[1:])
+        # Copy over previous output stacktrace
+        copy_stack_trace(node.outputs[0], subt_x)
+
+        rbcast_subt_x = unbroadcast(subt_x, *new_axes)
+        # Copy over previous output stacktrace
+        # and stacktrace from previous unary operation
+        copy_stack_trace([node.outputs[0], node.inputs[0]], rbcast_subt_x)
+
+        return [rbcast_subt_x]
+
+
 @register_infer_shape
 @register_useless
 @register_canonicalize
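
The axis bookkeeping in the split-off local_subtensor_of_unbroadcast can be illustrated in isolation. The helper below is a hypothetical standalone sketch (not part of the commit) of the same remapping: integer indices drop a dimension, slices keep it, and dimensions beyond the index list are carried over unchanged.

def remap_unbroadcast_axes(old_axes, idx_list, ndim):
    """Sketch of how Unbroadcast axes are renumbered after a Subtensor."""
    new_axes = []
    j = 0  # position in the broadcastable pattern *after* the Subtensor
    for i in range(ndim):
        idx = idx_list[i] if i < len(idx_list) else slice(None)
        if isinstance(idx, slice):
            # Slices keep the dimension, so the axis survives (possibly renumbered)
            if i in old_axes:
                new_axes.append(j)
            j += 1
        # Integer indices drop the dimension, so nothing is carried over
    return new_axes


# Unbroadcast axes (0, 2) of a 3d tensor, then index as x[0, 1:]:
# axis 0 is dropped by the integer index and axis 2 becomes axis 1.
assert remap_unbroadcast_axes((0, 2), [0, slice(1, None)], 3) == [1]
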
