39 changes: 39 additions & 0 deletions pytensor/link/pytorch/dispatch/elemwise.py
@@ -3,6 +3,7 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar import ScalarLoop
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -11,6 +12,7 @@
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op

base_fn = pytorch_funcify(scalar_op, node=node, **kwargs)

def check_special_scipy(func_name):
Expand All @@ -33,6 +35,9 @@
Elemwise._check_runtime_broadcast(node, inputs)
return base_fn(*inputs)

elif isinstance(scalar_op, ScalarLoop):
return elemwise_ravel_fn(base_fn, op, node, **kwargs)

else:

def elemwise_fn(*inputs):
@@ -176,3 +181,37 @@
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

return softmax_grad


def elemwise_ravel_fn(base_fn, op, node, **kwargs):
"""
Dispatch methods using `.item()` (ScalarLoop + Elemwise) is common, but vmap
in torch has a limitation: https://github.com/pymc-devs/pytensor/issues/1031,
Instead, we can ravel all the inputs, broadcasted according to torch
"""

n_outputs = len(node.outputs)

def elemwise_fn(*inputs):
bcasted_inputs = torch.broadcast_tensors(*inputs)
raveled_inputs = [inp.ravel() for inp in bcasted_inputs]

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
Member: Is there no torch.empty?

Contributor Author: My bad; I had an old version of torch on my machine (2.2) which didn't have it, but 2.3+ does. Reverted to torch.empty.
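As a side note, a minimal illustration (not part of the diff) of the difference being discussed: torch.empty only allocates the output buffer without initializing it, while torch.zeros also fills it with zeros; either is safe here because the loop below writes every index of the output.

import torch

out_size = 4
buf_zeros = torch.zeros(out_size)  # tensor([0., 0., 0., 0.])
buf_empty = torch.empty(out_size)  # allocated, contents arbitrary until written

# Both work in this dispatch because every position is written exactly once.
for i in range(out_size):
    buf_empty[i] = float(i)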


for i in range(out_size):
core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
if n_outputs == 1:
raveled_outputs[0][i] = core_outs

else:
for o in range(n_outputs):
raveled_outputs[o][i] = core_outs[o]

outputs = tuple(out.view(out_shape) for out in raveled_outputs)
if n_outputs == 1:
return outputs[0]

else:
return outputs

return elemwise_fn
36 changes: 36 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -1,12 +1,14 @@
import importlib

import torch
import torch.compiler

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus


@@ -62,3 +64,37 @@
@pytorch_funcify.register(Softplus)
def pytorch_funcify_Softplus(op, node, **kwargs):
return torch.nn.Softplus()


@pytorch_funcify.register(ScalarLoop)
def pytorch_funcify_ScalarLoop(op, node, **kwargs):
update = pytorch_funcify(op.fgraph, **kwargs)
state_length = op.nout
if op.is_while:

def scalar_loop(steps, *start_and_constants):
carry, constants = (

start_and_constants[:state_length],
start_and_constants[state_length:],
)
done = True

for _ in range(steps):
*carry, done = update(*carry, *constants)
if torch.any(done):
break
return *carry, done
else:

def scalar_loop(steps, *start_and_constants):
carry, constants = (

start_and_constants[:state_length],
start_and_constants[state_length:],
)
for _ in range(steps):
carry = update(*carry, *constants)
if len(node.outputs) == 1:
return carry[0]
else:
return carry

return scalar_loop
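For context, a minimal standalone sketch of how the while-variant above behaves; the update rule x <- x + 1 with stopping condition x >= 10 is an assumption that mirrors the tests further down, not part of this diff.

import torch

def update(x):
    # hypothetical single-state update: x <- x + 1, stop once x >= 10
    new_x = x + 1
    return new_x, new_x >= 10

def scalar_loop(steps, x0):
    carry = (x0,)
    done = True
    for _ in range(int(steps)):
        *carry, done = update(*carry)
        if torch.any(done):
            break
    return *carry, done

print(scalar_loop(torch.tensor(20), torch.tensor(0.0)))  # (tensor(10.), tensor(True))
print(scalar_loop(torch.tensor(5), torch.tensor(1.0)))   # (tensor(6.), tensor(False))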
17 changes: 15 additions & 2 deletions pytensor/link/pytorch/linker.py
@@ -9,6 +9,19 @@
super().__init__(*args, **kwargs)
self.gen_functors = []

def input_filter(self, inp):
from pytensor.link.pytorch.dispatch import pytorch_typify

return pytorch_typify(inp)

def output_filter(self, var, out):
from torch import is_tensor

Check warning on line 18 in pytensor/link/pytorch/linker.py

View check run for this annotation

Codecov / codecov/patch

pytensor/link/pytorch/linker.py#L18

Added line #L18 was not covered by tests

if is_tensor(out):
return out.cpu()

Contributor Author: This will probably create conflicts when one of my other PRs gets merged, as an FYI.
else:
return out

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

@@ -77,11 +90,11 @@
self.gen_functors = []

# Torch does not accept numpy inputs and may return GPU objects
def fn(*inputs, inner_fn=inner_fn):
def create_outputs(*inputs, inner_fn=inner_fn):
Member: Why the new name? It seems less clear.

Contributor Author: Yeah, this fn was shadowing a local variable fn, so I just renamed one of them.

Member: Sure, but can we use a different name? This doesn't "create_outputs"; it converts inputs to torch tensors and converts outputs back to pytensor-compatible types.

Contributor Author: Sure thing - I can also just keep the shadowing, lol. It's not the end of the world. From your description I would probably have called it convert_types or something.

Member (@ricardoV94, Dec 8, 2024): We can also put it inside the wrapper __call__ I guess?

Contributor Author: Sure, that makes sense too.
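For illustration, a minimal sketch of the alternative suggested above; the class name TorchIOWrapper and its attributes are hypothetical, not part of this PR. The idea is to move the numpy-to-torch input conversion and the torch-to-numpy output conversion into the wrapper's __call__ instead of a separate closure.

class TorchIOWrapper:
    """Hypothetical wrapper owning the numpy <-> torch conversions."""

    def __init__(self, inner_fn, typify):
        self.inner_fn = inner_fn  # the torch-compiled function
        self.typify = typify      # e.g. pytorch_typify

    def __call__(self, *inputs):
        # Convert inputs to torch tensors, run the compiled function,
        # then move each output back to a CPU numpy array.
        outs = self.inner_fn(*(self.typify(inp) for inp in inputs))
        return tuple(out.cpu().numpy() for out in outs)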

outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
return tuple(out.cpu().numpy() for out in outs)

return fn
return create_outputs

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
Expand Down
86 changes: 86 additions & 0 deletions tests/link/pytorch/test_basic.py
@@ -4,6 +4,7 @@
import numpy as np
import pytest

import pytensor.tensor as pt
import pytensor.tensor.basic as ptb
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
@@ -17,7 +18,10 @@
from pytensor.ifelse import ifelse
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.raise_op import CheckAndRaise
from pytensor.scalar import float64, int64
from pytensor.scalar.loop import ScalarLoop
from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.type import matrices, matrix, scalar, vector


@@ -385,3 +389,85 @@ def test_pytorch_softplus():
out = softplus(x)
f = FunctionGraph([x], [out])
compare_pytorch_and_py(f, [np.random.rand(3)])


def test_ScalarLoop():
n_steps = int64("n_steps")
x0 = float64("x0")
const = float64("const")
x = x0 + const

op = ScalarLoop(init=[x0], constant=[const], update=[x])
x = op(n_steps, x0, const)

fn = function([n_steps, x0, const], x, mode=pytorch_mode)
np.testing.assert_allclose(fn(5, 0, 1), 5)
np.testing.assert_allclose(fn(5, 0, 2), 10)
np.testing.assert_allclose(fn(4, 3, -1), -1)


def test_ScalarLoop_while():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 + 1
until = x >= 10

op = ScalarLoop(init=[x0], update=[x], until=until)
fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode)
for res, expected in zip(
[fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)],
[[10, True], [10, True], [6, False]],
strict=True,
):
np.testing.assert_allclose(res[0], np.array(expected[0]))
np.testing.assert_allclose(res[1], np.array(expected[1]))


def test_ScalarLoop_Elemwise_single_carries():
n_steps = int64("n_steps")
x0 = float64("x0")
x = x0 * 2
until = x >= 10

scalarop = ScalarLoop(init=[x0], update=[x], until=until)
op = Elemwise(scalarop)

n_steps = pt.scalar("n_steps", dtype="int32")
x0 = pt.vector("x0", dtype="float32")
state, done = op(n_steps, x0)

f = FunctionGraph([n_steps, x0], [state, done])
args = [
np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"),
]
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
)


def test_ScalarLoop_Elemwise_multi_carries():
n_steps = int64("n_steps")
x0 = float64("x0")
x1 = float64("x1")
x = x0 * 2
x1_n = x1 * 3
until = x >= 10

scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until)
op = Elemwise(scalarop)

n_steps = pt.scalar("n_steps", dtype="int32")
x0 = pt.vector("x0", dtype="float32")
x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1))
*states, done = op(n_steps, x0, x1)

f = FunctionGraph([n_steps, x0, x1], [*states, done])
args = [
np.array(10).astype("int32"),
np.arange(0, 5).astype("float32"),
np.random.rand(7, 3, 1).astype("float32"),
]
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
)