Description
While trying to simplify Function.__call__ (see #1024 and #222), I noticed some complicated logic that checks whether inputs marked as mutable (or borrowable) alias the same memory as one another.
pytensor/pytensor/compile/function/types.py
Lines 888 to 933 in be358ed
if (
    not self.trust_input
    and
    # The getattr is only needed for old pickle
    getattr(self, "_check_for_aliased_inputs", True)
):
    # Collect aliased inputs among the storage space
    args_share_memory = []
    for i in range(len(self.input_storage)):
        i_var = self.maker.inputs[i].variable
        i_val = self.input_storage[i].storage[0]
        if hasattr(i_var.type, "may_share_memory"):
            is_aliased = False
            for j in range(len(args_share_memory)):
                group_j = zip(
                    [
                        self.maker.inputs[k].variable
                        for k in args_share_memory[j]
                    ],
                    [
                        self.input_storage[k].storage[0]
                        for k in args_share_memory[j]
                    ],
                )
                if any(
                    (
                        var.type is i_var.type
                        and var.type.may_share_memory(val, i_val)
                    )
                    for (var, val) in group_j
                ):
                    is_aliased = True
                    args_share_memory[j].append(i)
                    break
            if not is_aliased:
                args_share_memory.append([i])
    # Check for groups of more than one argument that share memory
    for group in args_share_memory:
        if len(group) > 1:
            # copy all but the first
            for j in group[1:]:
                self.input_storage[j].storage[0] = copy.copy(
                    self.input_storage[j].storage[0]
                )
To avoid erroneous computation, __call__ tries to copy aliased inputs. However, this logic is wrong: it assumes that only variables of the same type can be aliased, which doesn't make sense. See the example below, where a vector and a matrix are aliased. The check misses the alias, so the function returns wrong values and corrupts the input y, which was not marked as mutable.
import pytensor
import pytensor.tensor as pt
from pytensor import In
import numpy as np
x = pt.vector()
y = pt.matrix()
fn = pytensor.function([In(x, mutable=True), In(y, mutable=False)], [x * 2, y * 2])
fn.dprint(print_destroy_map=True)
# Mul [id A] d={0: [1]} 0
# ├─ [2.] [id B]
# └─ <Vector(float64, shape=(?,))> [id C]
# Mul [id D] d={0: [1]} 1
# ├─ [[2.]] [id E]
# └─ <Matrix(float64, shape=(?, ?))> [id F]
y_val = np.ones((2, 5))
x_val = y_val[0] # x is an alias of y
res1, res2 = fn(x_val, y_val)
print(res1)
# [2. 2. 2. 2. 2.]
print(res2) # Wrong
# [[4. 4. 4. 4. 4.]
# [2. 2. 2. 2. 2.]]
print(y_val) # Corrupted
# [[2. 2. 2. 2. 2.]
# [1. 1. 1. 1. 1.]]
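For comparison, here is a rough sketch of what a type-agnostic version of the grouping would look like. It is not PyTensor code: `group_aliased` is a hypothetical helper, and it calls NumPy's `np.may_share_memory` directly as a stand-in for the per-type `may_share_memory` method. It keeps the greedy grouping of the snippet above but drops the `var.type is i_var.type` condition, so the vector/matrix alias from the example is detected. The pairwise comparisons are the overhead a more robust check would add to every call.

```python
import numpy as np


def group_aliased(values):
    """Greedily group indices of arrays whose memory bounds overlap.

    Hypothetical helper: it ignores the variable types entirely and only
    compares memory bounds, so arrays of different ndim can be grouped.
    """
    groups = []
    for i, val in enumerate(values):
        for group in groups:
            # Bounds-only check (cheap, may report false positives)
            if any(np.may_share_memory(val, values[k]) for k in group):
                group.append(i)
                break
        else:
            # No existing group overlaps with this value
            groups.append([i])
    return groups


y_val = np.ones((2, 5))
x_val = y_val[0]  # view of the first row of y_val
z_val = np.zeros(3)  # independent allocation

groups = group_aliased([x_val, y_val, z_val])
print(groups)  # [[0, 1], [2]]
```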
My suggestion is not to make the alias check more robust (and thereby increase the Function call overhead), but to forgo it completely. If users indicate that an input is mutable, it shouldn't be too surprising that views of that input (or other variables sharing the same underlying memory) are also corrupted.
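If the check were dropped, callers who cannot rule out aliasing could guard against it themselves before the call. The following is only a sketch under that assumption: `unalias` is a hypothetical helper, not part of PyTensor, and `double_inplace` is a plain NumPy stand-in for a compiled function that destroys its first input.

```python
import numpy as np


def unalias(mutable_val, *other_vals):
    """Return a copy of mutable_val if it may overlap any other input.

    Hypothetical caller-side guard; uses a bounds-only overlap check.
    """
    if any(np.may_share_memory(mutable_val, other) for other in other_vals):
        return mutable_val.copy()
    return mutable_val


def double_inplace(x, y):
    # Stand-in for a function that overwrites x but must not touch y
    x *= 2
    return x, y * 2


y_val = np.ones((2, 5))
x_val = y_val[0]  # x is an alias of y

res1, res2 = double_inplace(unalias(x_val, y_val), y_val)
```

With the guard in place, `res2` is computed from an untouched `y_val`, and `y_val` itself keeps its original values; the copy is paid only when the inputs may overlap.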