Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pytensor/scan/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1500,6 +1500,10 @@ def make_thunk(self, node, storage_map, compute_map, no_recycling, impl=None):
node_input_storage = [storage_map[r] for r in node.inputs]
node_output_storage = [storage_map[r] for r in node.outputs]

# HACK: Here to handle Blockwise Scans
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not a hack, compute map for inner graphs is useless.

Instead of creating a compute_map make an if/else below where rval is returned and return an rval that doesn't try to assign to compute_map if it was None

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ricardoV94 is this change alright?

if compute_map is None:
compute_map = {out: [False] for out in node.outputs}

# Analyse the compile inner function to determine which inputs and
# outputs are on the gpu and speed up some checks during the execution
outs_is_tensor = [
Expand Down
12 changes: 12 additions & 0 deletions tests/scan/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from pytensor.compile.sharedvalue import shared
from pytensor.configdefaults import config
from pytensor.gradient import NullTypeGradError, Rop, disconnected_grad, grad, hessian
from pytensor.graph import vectorize_graph
from pytensor.graph.basic import Apply, ancestors, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
Expand Down Expand Up @@ -1178,6 +1179,17 @@ def get_sum_of_grad(input0, input1):

utt.verify_grad(get_sum_of_grad, inputs_test_values, rng=rng)

def test_blockwise_scan(self):
    """Vectorizing a scalar-carried scan into a Blockwise scan.

    Builds a 10-step ``x -> x + 1`` recurrence on a scalar initial value,
    then uses ``vectorize_graph`` to batch it over a vector of initial
    values, and checks the compiled batched result against the
    hand-computed recurrence.
    """
    scalar_init = pt.tensor("x", shape=())
    trace, _ = scan(lambda prev: prev + 1, outputs_info=[scalar_init], n_steps=10)

    # Substitute a batched (vector) initial value for the scalar one.
    batched_init = pt.tensor("x_vec", shape=(None,))
    batched_trace = vectorize_graph(trace, {scalar_init: batched_init})

    fn = function([batched_init], batched_trace)
    actual = fn([1, 2, 3])
    # Each step adds 1: for init i the trace is arange(i + 1, i + 11),
    # so rows for inits [1, 2, 3] are arange(2, 12) shifted by 0, 1, 2.
    expected = np.arange(2, 12) + np.arange(3).reshape(-1, 1)
    assert np.allclose(actual, expected)

def test_connection_pattern(self):
"""Test `Scan.connection_pattern` in the presence of recurrent outputs with multiple taps."""

Expand Down