From 7804b904fd15a7fc8cb5844263a3d38a0abbb263 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Wed, 31 Jul 2024 20:32:14 -0700 Subject: [PATCH 01/26] Add for loop based scalar loop --- pytensor/link/pytorch/dispatch/scalar.py | 22 +++++++++++++++++++++- tests/link/pytorch/test_basic.py | 16 ++++++++++++++++ 2 files changed, 37 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 1416e58f55..e25133457b 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -8,6 +8,7 @@ ScalarOp, ) from pytensor.scalar.math import Softplus +from pytensor.scalar.loop import ScalarLoop @pytorch_funcify.register(ScalarOp) @@ -58,7 +59,26 @@ def cast(x): return cast - @pytorch_funcify.register(Softplus) def pytorch_funcify_Softplus(op, node, **kwargs): return torch.nn.Softplus() + +@pytorch_funcify.register(ScalarLoop) +def pytorch_funicify_ScalarLoop(op, node, **kwargs): + update = pytorch_funcify(op.fgraph) + + def inner(steps, start, constant, update=update, is_while=op.is_while): + # easiest way to do it is to loop + c = start + for i in range(steps): + outs = update(c, constant) + if is_while: + n, done = outs + if done: + return n + c = n + else: + c = outs[0] + return c + + return inner diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 83249d021b..f82dee3d15 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -17,6 +17,8 @@ from pytensor.ifelse import ifelse from pytensor.link.pytorch.linker import PytorchLinker from pytensor.raise_op import CheckAndRaise +from pytensor.scalar import float64, int64 +from pytensor.scalar.loop import ScalarLoop from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus from pytensor.tensor.type import matrices, matrix, scalar, vector @@ -388,3 +390,17 @@ def test_pytorch_softplus(): out = softplus(x) f = FunctionGraph([x], [out]) compare_pytorch_and_py(f, [np.random.rand(3)]) + +def test_ScalarLoop(): + n_steps = int64("n_steps") + x0 = float64("x0") + const = float64("const") + x = x0 + const + + op = ScalarLoop(init=[x0], constant=[const], update=[x]) + x = op(n_steps, x0, const) + + fn = function([n_steps, x0, const], x, mode=pytorch_mode) + np.testing.assert_allclose(fn(5, 0, 1), 5) + np.testing.assert_allclose(fn(5, 0, 2), 10) + np.testing.assert_allclose(fn(4, 3, -1), -1) From 12569f8209905a1993bec3fe49231e6d1ba30936 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 11 Aug 2024 11:50:07 -0700 Subject: [PATCH 02/26] Pass all loop tests --- pytensor/link/pytorch/dispatch/scalar.py | 36 +++++++++++++++--------- 1 file changed, 22 insertions(+), 14 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index e25133457b..ef849c08f6 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -66,19 +66,27 @@ def pytorch_funcify_Softplus(op, node, **kwargs): @pytorch_funcify.register(ScalarLoop) def pytorch_funicify_ScalarLoop(op, node, **kwargs): update = pytorch_funcify(op.fgraph) - - def inner(steps, start, constant, update=update, is_while=op.is_while): - # easiest way to do it is to loop - c = start - for i in range(steps): - outs = update(c, constant) - if is_while: - n, done = outs + if op.is_while: + + def scalar_loop(steps, *start_and_constants): + *carry, constants = start_and_constants + constants = constants.unsqueeze(0) + done = True + for _ 
in range(steps): + *carry, done = update(*carry, *constants) + constants = start_and_constants[len(carry) :] if done: - return n - c = n - else: - c = outs[0] - return c + break + return torch.stack((*carry, done)) + else: + + def scalar_loop(*args): + steps, *start_and_constants = args + *carry, constants = start_and_constants + constants = constants.unsqueeze(0) + for i in range(steps): + carry = update(*carry, *constants) + constants = start_and_constants[len(carry) :] + return torch.stack(carry) - return inner + return scalar_loop From 8eff3fe30d1b20b25de33fe635721babff6e7fa4 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 11 Aug 2024 21:02:10 -0700 Subject: [PATCH 03/26] Fetch constants from op --- pytensor/link/pytorch/dispatch/scalar.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index ef849c08f6..62c266487e 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -66,15 +66,17 @@ def pytorch_funcify_Softplus(op, node, **kwargs): @pytorch_funcify.register(ScalarLoop) def pytorch_funicify_ScalarLoop(op, node, **kwargs): update = pytorch_funcify(op.fgraph) + state_length = op.nout if op.is_while: def scalar_loop(steps, *start_and_constants): - *carry, constants = start_and_constants - constants = constants.unsqueeze(0) + carry, constants = ( + start_and_constants[:state_length], + start_and_constants[state_length:], + ) done = True for _ in range(steps): *carry, done = update(*carry, *constants) - constants = start_and_constants[len(carry) :] if done: break return torch.stack((*carry, done)) @@ -82,11 +84,12 @@ def scalar_loop(steps, *start_and_constants): def scalar_loop(*args): steps, *start_and_constants = args - *carry, constants = start_and_constants - constants = constants.unsqueeze(0) - for i in range(steps): + carry, constants = ( + start_and_constants[:state_length], + start_and_constants[state_length:], + ) + for _ in range(steps): carry = update(*carry, *constants) - constants = start_and_constants[len(carry) :] return torch.stack(carry) return scalar_loop From e0bbde8ec46b589d9a9e50daa2d5f1e978aab867 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 8 Sep 2024 16:33:06 -0700 Subject: [PATCH 04/26] Add while loop test --- tests/link/pytorch/test_basic.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index f82dee3d15..ba22e6eaba 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -404,3 +404,16 @@ def test_ScalarLoop(): np.testing.assert_allclose(fn(5, 0, 1), 5) np.testing.assert_allclose(fn(5, 0, 2), 10) np.testing.assert_allclose(fn(4, 3, -1), -1) + + +def test_ScalarLoop_while(): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 + 1 + until = x >= 10 + + op = ScalarLoop(init=[x0], update=[x], until=until) + fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode) + np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True]) + np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True]) + np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False]) From ae1c9da6f880f5aee696f984aa862d1c40ee3d30 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 12 Sep 2024 16:51:30 -0700 Subject: [PATCH 05/26] Fix while loop and nasty stack over dtypes --- pytensor/link/pytorch/dispatch/scalar.py | 2 +- tests/link/pytorch/test_basic.py | 9 ++++++--- 2 files changed, 7 
insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 62c266487e..c2f9e111e1 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -79,7 +79,7 @@ def scalar_loop(steps, *start_and_constants): *carry, done = update(*carry, *constants) if done: break - return torch.stack((*carry, done)) + return torch.stack(carry), torch.tensor([done]) else: def scalar_loop(*args): diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index ba22e6eaba..6cca8be984 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -414,6 +414,9 @@ def test_ScalarLoop_while(): op = ScalarLoop(init=[x0], update=[x], until=until) fn = function([n_steps, x0], op(n_steps, x0), mode=pytorch_mode) - np.testing.assert_allclose(fn(n_steps=20, x0=0), [10, True]) - np.testing.assert_allclose(fn(n_steps=20, x0=1), [10, True]) - np.testing.assert_allclose(fn(n_steps=5, x0=1), [6, False]) + for res, expected in zip( + [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)], + [[10, True], [10, True], [6, False]], + ): + np.testing.assert_allclose(res[0], np.array(expected[0])) + np.testing.assert_allclose(res[1], np.array(expected[1])) From 2844bc46f393fc976aecc20e7a470a60ea1604e4 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Fri, 13 Sep 2024 09:59:38 -0700 Subject: [PATCH 06/26] Disable compile here based on CI result --- pytensor/link/pytorch/dispatch/scalar.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index c2f9e111e1..8a635c0760 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -1,6 +1,7 @@ import importlib import torch +import torch.compiler from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( @@ -92,4 +93,4 @@ def scalar_loop(*args): carry = update(*carry, *constants) return torch.stack(carry) - return scalar_loop + return torch.compiler.disable(scalar_loop) From 39ff3de0cf43fd60cd4bed6328c927224bbd0b28 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Fri, 13 Sep 2024 10:42:40 -0700 Subject: [PATCH 07/26] Fix mypy signature --- pytensor/link/pytorch/dispatch/scalar.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 8a635c0760..1a24817370 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -83,8 +83,7 @@ def scalar_loop(steps, *start_and_constants): return torch.stack(carry), torch.tensor([done]) else: - def scalar_loop(*args): - steps, *start_and_constants = args + def scalar_loop(steps, *start_and_constants): carry, constants = ( start_and_constants[:state_length], start_and_constants[state_length:], From 714759c1dd5ec7577b94454d6efc751c996013fe Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 19 Sep 2024 13:49:02 -0700 Subject: [PATCH 08/26] Remove unnecessary torch stack --- pytensor/link/pytorch/dispatch/scalar.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 1a24817370..31ad0d0713 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -80,7 +80,10 @@ def scalar_loop(steps, 
*start_and_constants): *carry, done = update(*carry, *constants) if done: break - return torch.stack(carry), torch.tensor([done]) + if len(node.outputs) == 2: + return carry[0], done + else: + return carry, done else: def scalar_loop(steps, *start_and_constants): @@ -90,6 +93,9 @@ def scalar_loop(steps, *start_and_constants): ) for _ in range(steps): carry = update(*carry, *constants) - return torch.stack(carry) + if len(node.outputs) == 1: + return carry[0] + else: + return carry return torch.compiler.disable(scalar_loop) From 623dfbe04fbd077ab09096da097efd1640e749d4 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 19 Sep 2024 13:49:24 -0700 Subject: [PATCH 09/26] Only call .cpu when necessary --- pytensor/link/pytorch/linker.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index ec26fd252f..e710dcb659 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,6 +1,8 @@ import copy from typing import Any +from torch import is_tensor + from pytensor.graph.basic import Variable from pytensor.link.basic import JITLinker from pytensor.link.utils import unique_name_generator @@ -19,7 +21,10 @@ def input_filter(self, inp: Any) -> Any: return pytorch_typify(inp) def output_filter(self, var: Variable, out: Any) -> Any: - return out.cpu() + if is_tensor(out): + return out.cpu() + else: + return out def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify From e06994fb777fd0f66c7185fa5f89995391b4570b Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 19 Sep 2024 13:50:43 -0700 Subject: [PATCH 10/26] Recursive false for torch compiler --- pytensor/link/pytorch/dispatch/scalar.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 31ad0d0713..30b1c24c21 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -98,4 +98,4 @@ def scalar_loop(steps, *start_and_constants): else: return carry - return torch.compiler.disable(scalar_loop) + return torch.compiler.disable(scalar_loop, recursive=False) From 977d98d80a403adb1f98f3eefe56f10a9fbeaf11 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 19 Sep 2024 15:00:23 -0700 Subject: [PATCH 11/26] Add elemwise test --- pytensor/link/pytorch/dispatch/scalar.py | 2 +- tests/link/pytorch/test_basic.py | 31 ++++++++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 30b1c24c21..f6088e6967 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -78,7 +78,7 @@ def scalar_loop(steps, *start_and_constants): done = True for _ in range(steps): *carry, done = update(*carry, *constants) - if done: + if torch.any(done): break if len(node.outputs) == 2: return carry[0], done diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 6cca8be984..b956a1cf00 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -20,6 +20,7 @@ from pytensor.scalar import float64, int64 from pytensor.scalar.loop import ScalarLoop from pytensor.tensor import alloc, arange, as_tensor, empty, expit, eye, softplus +from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.type import matrices, matrix, scalar, vector 
@@ -420,3 +421,33 @@ def test_ScalarLoop_while(): ): np.testing.assert_allclose(res[0], np.array(expected[0])) np.testing.assert_allclose(res[1], np.array(expected[1])) + +def test_pytorch_OpFromGraph(): + x, y, z = matrices("xyz") + ofg_1 = OpFromGraph([x, y], [x + y]) + ofg_2 = OpFromGraph([x, y], [x * y, x - y]) + + o1, o2 = ofg_2(y, z) + out = ofg_1(x, o1) + o2 + + xv = np.ones((2, 2), dtype=config.floatX) + yv = np.ones((2, 2), dtype=config.floatX) * 3 + zv = np.ones((2, 2), dtype=config.floatX) * 5 + + f = FunctionGraph([x, y, z], [out]) + compare_pytorch_and_py(f, [xv, yv, zv]) + + +def test_ScalarLoop_Elemwise(): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 * 2 + until = x >= 10 + + op = ScalarLoop(init=[x0], update=[x], until=until) + fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode) + + states, dones = fn(10, np.array(range(5))) + + np.testing.assert_allclose(states, [0, 4, 8, 12, 16]) + np.testing.assert_allclose(dones, [False, False, False, True, True]) From 07e45206a89269cd13dd6b5a7c6b278ec8203dba Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 19 Sep 2024 15:08:51 -0700 Subject: [PATCH 12/26] Late import torch --- pytensor/link/pytorch/linker.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index e710dcb659..491bbc8fbb 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -1,8 +1,6 @@ import copy from typing import Any -from torch import is_tensor - from pytensor.graph.basic import Variable from pytensor.link.basic import JITLinker from pytensor.link.utils import unique_name_generator @@ -21,6 +19,8 @@ def input_filter(self, inp: Any) -> Any: return pytorch_typify(inp) def output_filter(self, var: Variable, out: Any) -> Any: + from torch import is_tensor + if is_tensor(out): return out.cpu() else: From 561301b45eb9f98f6da6b1efe8eb143a791b9186 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sat, 26 Oct 2024 22:35:26 -0700 Subject: [PATCH 13/26] Do iteration instead of vmap for elemwise --- pytensor/link/pytorch/dispatch/elemwise.py | 30 ++++++++++++++++++++++ tests/link/pytorch/test_basic.py | 20 +++++++++++---- 2 files changed, 45 insertions(+), 5 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 79ca5beec1..ffe10d680a 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,8 +1,10 @@ import importlib +from itertools import chain import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify +from pytensor.scalar import ScalarLoop from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad @@ -33,6 +35,33 @@ def elemwise_fn(*inputs): Elemwise._check_runtime_broadcast(node, inputs) return base_fn(*inputs) + elif isinstance(scalar_op, ScalarLoop): + # note: scalarloop + elemwise is too common + # to not work, but @1031, vmap won't allow it. 
+ # Instead, we will just successively unbind + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) + shaped_inputs = torch.broadcast_tensors(*inputs) + expected_size = shaped_inputs[0].numel() + final_inputs = [s.clone() for s in shaped_inputs] + for _ in range(shaped_inputs[0].dim() - 1): + for i, _ in enumerate(shaped_inputs): + layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) + final_inputs[i] = list(layer) + + # make sure we still have the same number of things + assert len(final_inputs) == len(shaped_inputs) + + # make sure each group of things are the expected size + assert all(len(x) == expected_size for x in final_inputs) + + # make sure they are all single elements + assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) + res = [base_fn(*args) for args in zip(*final_inputs)] + states = torch.stack(tuple(out[0] for out in res)) + done = torch.stack(tuple(out[1] for out in res)) + return states, done + else: def elemwise_fn(*inputs): @@ -42,6 +71,7 @@ def elemwise_fn(*inputs): for _ in range(broadcast_inputs[0].dim()): ufunc = torch.vmap(ufunc) return ufunc(*broadcast_inputs) + return base_fn(*inputs) return elemwise_fn diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index b956a1cf00..1850d9403c 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -4,6 +4,7 @@ import numpy as np import pytest +import pytensor.tensor as pt import pytensor.tensor.basic as ptb from pytensor.compile.builders import OpFromGraph from pytensor.compile.function import function @@ -444,10 +445,19 @@ def test_ScalarLoop_Elemwise(): x = x0 * 2 until = x >= 10 - op = ScalarLoop(init=[x0], update=[x], until=until) - fn = function([n_steps, x0], Elemwise(op)(n_steps, x0), mode=pytorch_mode) + scalarop = ScalarLoop(init=[x0], update=[x], until=until) + op = Elemwise(scalarop) + + n_steps = pt.scalar("n_steps", dtype="int32") + x0 = pt.vector("x0", dtype="float32") + state, done = op(n_steps, x0) + + fn = function([n_steps, x0], [state, done], mode=pytorch_mode) + py_fn = function([n_steps, x0], [state, done]) - states, dones = fn(10, np.array(range(5))) + args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] + torch_states, torch_dones = fn(*args) + py_states, py_dones = py_fn(*args) - np.testing.assert_allclose(states, [0, 4, 8, 12, 16]) - np.testing.assert_allclose(dones, [False, False, False, True, True]) + np.testing.assert_allclose(torch_states, py_states) + np.testing.assert_allclose(torch_dones, py_dones) From e28c3e209c1043145b7dc710996ce5a500701f53 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sat, 2 Nov 2024 16:57:57 -0700 Subject: [PATCH 14/26] Clean up and add description --- pytensor/link/pytorch/dispatch/elemwise.py | 80 +++++++++++++++------- 1 file changed, 55 insertions(+), 25 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index ffe10d680a..1c7e3cc254 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -13,6 +13,7 @@ @pytorch_funcify.register(Elemwise) def pytorch_funcify_Elemwise(op, node, **kwargs): scalar_op = op.scalar_op + base_fn = pytorch_funcify(scalar_op, node=node, **kwargs) def check_special_scipy(func_name): @@ -36,31 +37,7 @@ def elemwise_fn(*inputs): return base_fn(*inputs) elif isinstance(scalar_op, ScalarLoop): - # note: scalarloop + elemwise is too common - # to not work, but @1031, vmap won't allow it. 
- # Instead, we will just successively unbind - def elemwise_fn(*inputs): - Elemwise._check_runtime_broadcast(node, inputs) - shaped_inputs = torch.broadcast_tensors(*inputs) - expected_size = shaped_inputs[0].numel() - final_inputs = [s.clone() for s in shaped_inputs] - for _ in range(shaped_inputs[0].dim() - 1): - for i, _ in enumerate(shaped_inputs): - layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) - final_inputs[i] = list(layer) - - # make sure we still have the same number of things - assert len(final_inputs) == len(shaped_inputs) - - # make sure each group of things are the expected size - assert all(len(x) == expected_size for x in final_inputs) - - # make sure they are all single elements - assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) - res = [base_fn(*args) for args in zip(*final_inputs)] - states = torch.stack(tuple(out[0] for out in res)) - done = torch.stack(tuple(out[1] for out in res)) - return states, done + return elemwise_scalar_loop(base_fn, op, node, **kwargs) else: @@ -206,3 +183,56 @@ def softmax_grad(dy, sm): return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm return softmax_grad + + +def elemwise_scalar_loop(base_fn, op, node, **kwargs): + """ + ScalarLoop + Elemwise is too common + to not work, but @1031, vmap won't allow it. + Instead, we can do the following strategy + 1. `.unbind(dim)` will return a list of tensors + representing `dim` but "unwrapped". e.x. + ``` + t = torch.ones(3, 4, 2) + len(t.unbind(0)) == 3 + t[0].shape == torch.Size[4, 2] + 2. If we successfully apply, the length of the list will grow + by the next dimension in the tensor if we flatten the previous + dimension result + ``` + inputs = [torch.ones(3, 4, 2)] + level_1 = chain.from_iterable(t.unbind(0) for t in inputs) + level_2 = chain.from_iterable(t.unbind(0) for t in level_1) + len(level_2) == 3 * 4 + ``` + 3. Eventually we'll reach single dimension tensors. 
At that point + we can iterate over each input in an element by element manner + and call some function + + For scalar loop, we need to broadcast the tensors so all + the necessary values are repeated, and we "evenly" iterate through everything + """ + + def elemwise_fn(*inputs): + Elemwise._check_runtime_broadcast(node, inputs) + shaped_inputs = torch.broadcast_tensors(*inputs) + expected_size = shaped_inputs[0].numel() + final_inputs = [s.clone() for s in shaped_inputs] + for _ in range(shaped_inputs[0].dim() - 1): + for i, _ in enumerate(shaped_inputs): + layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) + final_inputs[i] = list(layer) + + # make sure we still have the same number of things + assert len(final_inputs) == len(shaped_inputs) + + # make sure each group of things are the expected size + assert all(len(x) == expected_size for x in final_inputs) + + # make sure they are all single elements + assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) + res = [base_fn(*args) for args in zip(*final_inputs)] + + return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))] + + return elemwise_fn From fb905006ce341040596e8f358f1fb8562d5b7679 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sat, 2 Nov 2024 17:59:17 -0700 Subject: [PATCH 15/26] Add unit test to verify iteration --- tests/link/pytorch/test_basic.py | 35 ++++++++++++++++++++++++-------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 1850d9403c..434ee690dd 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -1,5 +1,7 @@ from collections.abc import Callable, Iterable from functools import partial +from itertools import repeat, starmap +from unittest.mock import MagicMock, call, patch import numpy as np import pytest @@ -452,12 +454,29 @@ def test_ScalarLoop_Elemwise(): x0 = pt.vector("x0", dtype="float32") state, done = op(n_steps, x0) - fn = function([n_steps, x0], [state, done], mode=pytorch_mode) - py_fn = function([n_steps, x0], [state, done]) - + f = FunctionGraph([n_steps, x0], [state, done]) args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] - torch_states, torch_dones = fn(*args) - py_states, py_dones = py_fn(*args) - - np.testing.assert_allclose(torch_states, py_states) - np.testing.assert_allclose(torch_dones, py_dones) + compare_pytorch_and_py(f, args) + + +torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") + + +@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) +@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") +def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): + args = [torch.ones(*s) for s in input_shapes[:-1]] + [ + torch.zeros(*input_shapes[-1]) + ] + mock_inner_func = MagicMock() + ret_value = torch.rand(2, 2).unbind(0) + mock_inner_func.f.return_value = ret_value + elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) + result = elemwise_fn(*args) + for actual, expected in zip(ret_value, result): + assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) + np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) + + expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) + expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) + mock_inner_func.f.assert_has_calls(expected_calls) From cd678ef1284944dab44d25593b4d7e2944686e27 
Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 19:17:49 -0800 Subject: [PATCH 16/26] Refactor to ravel method --- pytensor/link/pytorch/dispatch/elemwise.py | 68 ++++++++-------------- tests/link/pytorch/test_basic.py | 41 +------------ 2 files changed, 25 insertions(+), 84 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 1c7e3cc254..db37ed0343 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,6 +1,4 @@ import importlib -from itertools import chain - import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify @@ -189,50 +187,32 @@ def elemwise_scalar_loop(base_fn, op, node, **kwargs): """ ScalarLoop + Elemwise is too common to not work, but @1031, vmap won't allow it. - Instead, we can do the following strategy - 1. `.unbind(dim)` will return a list of tensors - representing `dim` but "unwrapped". e.x. - ``` - t = torch.ones(3, 4, 2) - len(t.unbind(0)) == 3 - t[0].shape == torch.Size[4, 2] - 2. If we successfully apply, the length of the list will grow - by the next dimension in the tensor if we flatten the previous - dimension result - ``` - inputs = [torch.ones(3, 4, 2)] - level_1 = chain.from_iterable(t.unbind(0) for t in inputs) - level_2 = chain.from_iterable(t.unbind(0) for t in level_1) - len(level_2) == 3 * 4 - ``` - 3. Eventually we'll reach single dimension tensors. At that point - we can iterate over each input in an element by element manner - and call some function - - For scalar loop, we need to broadcast the tensors so all - the necessary values are repeated, and we "evenly" iterate through everything + Instead, we can ravel all the inputs, broadcasted + according to torch """ + n_outputs = len(node.outputs) + def elemwise_fn(*inputs): - Elemwise._check_runtime_broadcast(node, inputs) - shaped_inputs = torch.broadcast_tensors(*inputs) - expected_size = shaped_inputs[0].numel() - final_inputs = [s.clone() for s in shaped_inputs] - for _ in range(shaped_inputs[0].dim() - 1): - for i, _ in enumerate(shaped_inputs): - layer = chain.from_iterable([s.unbind(0) for s in final_inputs[i]]) - final_inputs[i] = list(layer) - - # make sure we still have the same number of things - assert len(final_inputs) == len(shaped_inputs) - - # make sure each group of things are the expected size - assert all(len(x) == expected_size for x in final_inputs) - - # make sure they are all single elements - assert all(len(x.shape) == 0 for tensor in final_inputs for x in tensor) - res = [base_fn(*args) for args in zip(*final_inputs)] - - return [torch.stack(tuple(out[i] for out in res)) for i in range(len(res[0]))] + bcasted_inputs = torch.broadcast_tensors(*inputs) + raveled_inputs = [inp.ravel() for inp in bcasted_inputs] + + out_shape = bcasted_inputs[0].size() + out_size = out_shape.numel() + raveled_outputs = [torch.zeros(out_size) for out in node.outputs] + + for i in range(out_size): + core_outs = base_fn(*(inp[i] for inp in raveled_inputs)) + if n_outputs == 1: + raveled_outputs[0][i] = core_outs + else: + for o in range(n_outputs): + raveled_outputs[o][i] = core_outs[o] + + outputs = tuple(out.view(out_shape) for out in raveled_outputs) + if n_outputs == 1: + return outputs[0] + else: + return outputs return elemwise_fn diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 434ee690dd..04fdfac6d0 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -1,7 +1,5 @@ 
from collections.abc import Callable, Iterable from functools import partial -from itertools import repeat, starmap -from unittest.mock import MagicMock, call, patch import numpy as np import pytest @@ -421,25 +419,11 @@ def test_ScalarLoop_while(): for res, expected in zip( [fn(n_steps=20, x0=0), fn(n_steps=20, x0=1), fn(n_steps=5, x0=1)], [[10, True], [10, True], [6, False]], + strict=True, ): np.testing.assert_allclose(res[0], np.array(expected[0])) np.testing.assert_allclose(res[1], np.array(expected[1])) -def test_pytorch_OpFromGraph(): - x, y, z = matrices("xyz") - ofg_1 = OpFromGraph([x, y], [x + y]) - ofg_2 = OpFromGraph([x, y], [x * y, x - y]) - - o1, o2 = ofg_2(y, z) - out = ofg_1(x, o1) + o2 - - xv = np.ones((2, 2), dtype=config.floatX) - yv = np.ones((2, 2), dtype=config.floatX) * 3 - zv = np.ones((2, 2), dtype=config.floatX) * 5 - - f = FunctionGraph([x, y, z], [out]) - compare_pytorch_and_py(f, [xv, yv, zv]) - def test_ScalarLoop_Elemwise(): n_steps = int64("n_steps") @@ -457,26 +441,3 @@ def test_ScalarLoop_Elemwise(): f = FunctionGraph([n_steps, x0], [state, done]) args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] compare_pytorch_and_py(f, args) - - -torch_elemwise = pytest.importorskip("pytensor.link.pytorch.dispatch.elemwise") - - -@pytest.mark.parametrize("input_shapes", [[(5, 1, 1, 8), (3, 1, 1), (8,)]]) -@patch("pytensor.link.pytorch.dispatch.elemwise.Elemwise") -def test_ScalarLoop_Elemwise_iteration_logic(_, input_shapes): - args = [torch.ones(*s) for s in input_shapes[:-1]] + [ - torch.zeros(*input_shapes[-1]) - ] - mock_inner_func = MagicMock() - ret_value = torch.rand(2, 2).unbind(0) - mock_inner_func.f.return_value = ret_value - elemwise_fn = torch_elemwise.elemwise_scalar_loop(mock_inner_func.f, None, None) - result = elemwise_fn(*args) - for actual, expected in zip(ret_value, result): - assert torch.all(torch.eq(*torch.broadcast_tensors(actual, expected))) - np.testing.assert_equal(mock_inner_func.f.call_count, len(result[0])) - - expected_args = torch.FloatTensor([1.0] * (len(input_shapes) - 1) + [0.0]).unbind(0) - expected_calls = starmap(call, repeat(expected_args, mock_inner_func.f.call_count)) - mock_inner_func.f.assert_has_calls(expected_calls) From 1865de9dc17c001a547e55efc2a6a6a84d0edf34 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 19:19:31 -0800 Subject: [PATCH 17/26] Fix unpacking Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> --- pytensor/link/pytorch/dispatch/scalar.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index f6088e6967..0ed4597809 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -80,10 +80,7 @@ def scalar_loop(steps, *start_and_constants): *carry, done = update(*carry, *constants) if torch.any(done): break - if len(node.outputs) == 2: - return carry[0], done - else: - return carry, done + return *carry, done else: def scalar_loop(steps, *start_and_constants): From 1ffd7c60ecdb6ccb89b81e644537a3805d1c7399 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 19:21:04 -0800 Subject: [PATCH 18/26] Fix comment --- pytensor/link/pytorch/dispatch/elemwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index db37ed0343..58de8e8a34 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py 
+++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -186,8 +186,8 @@ def softmax_grad(dy, sm): def elemwise_scalar_loop(base_fn, op, node, **kwargs): """ ScalarLoop + Elemwise is too common - to not work, but @1031, vmap won't allow it. - Instead, we can ravel all the inputs, broadcasted + to not work, but https://github.com/pymc-devs/pytensor/issues/1031, + vmap won't allow it. Instead, we can ravel all the inputs, broadcasted according to torch """ From fd2f192e619d0b8a294b3891ffd48c470f7f4b47 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 19:22:25 -0800 Subject: [PATCH 19/26] Remove extra return --- pytensor/link/pytorch/dispatch/elemwise.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 58de8e8a34..79183da08d 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -46,7 +46,6 @@ def elemwise_fn(*inputs): for _ in range(broadcast_inputs[0].dim()): ufunc = torch.vmap(ufunc) return ufunc(*broadcast_inputs) - return base_fn(*inputs) return elemwise_fn From 7027c4c7210a48f5a84fd13963ec94832e84b9d0 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 20:19:10 -0800 Subject: [PATCH 20/26] Update test --- pytensor/link/pytorch/dispatch/scalar.py | 3 ++- tests/link/pytorch/test_basic.py | 21 +++++++++++++++------ 2 files changed, 17 insertions(+), 7 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 0ed4597809..e5bf55fbf6 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -64,6 +64,7 @@ def cast(x): def pytorch_funcify_Softplus(op, node, **kwargs): return torch.nn.Softplus() + @pytorch_funcify.register(ScalarLoop) def pytorch_funicify_ScalarLoop(op, node, **kwargs): update = pytorch_funcify(op.fgraph) @@ -80,7 +81,7 @@ def scalar_loop(steps, *start_and_constants): *carry, done = update(*carry, *constants) if torch.any(done): break - return *carry, done + return *carry, done else: def scalar_loop(steps, *start_and_constants): diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 04fdfac6d0..350da40188 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -428,16 +428,25 @@ def test_ScalarLoop_while(): def test_ScalarLoop_Elemwise(): n_steps = int64("n_steps") x0 = float64("x0") + x1 = float64("x1") x = x0 * 2 + x1_n = x1 * 3 until = x >= 10 - scalarop = ScalarLoop(init=[x0], update=[x], until=until) + scalarop = ScalarLoop(init=[x0, x1], update=[x, x1_n], until=until) op = Elemwise(scalarop) n_steps = pt.scalar("n_steps", dtype="int32") x0 = pt.vector("x0", dtype="float32") - state, done = op(n_steps, x0) - - f = FunctionGraph([n_steps, x0], [state, done]) - args = [np.array(10).astype("int32"), np.arange(0, 5).astype("float32")] - compare_pytorch_and_py(f, args) + x1 = pt.tensor("c0", dtype="float32", shape=(7, 3, 1)) + *states, done = op(n_steps, x0, x1) + + f = FunctionGraph([n_steps, x0, x1], [*states, done]) + args = [ + np.array(10).astype("int32"), + np.arange(0, 5).astype("float32"), + np.random.rand(7, 3, 1).astype("float32"), + ] + compare_pytorch_and_py( + f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + ) From 2a9ffd34fbdfa85603c6c820bb0f5fd7b661b720 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 24 Nov 2024 20:21:55 -0800 Subject: [PATCH 21/26] Add single carry test --- 
tests/link/pytorch/test_basic.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 350da40188..93cd179622 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -425,7 +425,30 @@ def test_ScalarLoop_while(): np.testing.assert_allclose(res[1], np.array(expected[1])) -def test_ScalarLoop_Elemwise(): +def test_ScalarLoop_Elemwise_single_carries(): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 * 2 + until = x >= 10 + + scalarop = ScalarLoop(init=[x0], update=[x], until=until) + op = Elemwise(scalarop) + + n_steps = pt.scalar("n_steps", dtype="int32") + x0 = pt.vector("x0", dtype="float32") + state, done = op(n_steps, x0) + + f = FunctionGraph([n_steps, x0], [state, done]) + args = [ + np.array(10).astype("int32"), + np.arange(0, 5).astype("float32"), + ] + compare_pytorch_and_py( + f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) + ) + + +def test_ScalarLoop_Elemwise_multi_carries(): n_steps = int64("n_steps") x0 = float64("x0") x1 = float64("x1") From 4ebdd154e1b8ad59d6bb21eb74ba682c922569c9 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Mon, 25 Nov 2024 09:42:20 -0800 Subject: [PATCH 22/26] Remove compiler disable --- pytensor/link/pytorch/dispatch/scalar.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index e5bf55fbf6..c27204f50d 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -67,7 +67,7 @@ def pytorch_funcify_Softplus(op, node, **kwargs): @pytorch_funcify.register(ScalarLoop) def pytorch_funicify_ScalarLoop(op, node, **kwargs): - update = pytorch_funcify(op.fgraph) + update = pytorch_funcify(op.fgraph, **kwargs) state_length = op.nout if op.is_while: @@ -96,4 +96,4 @@ def scalar_loop(steps, *start_and_constants): else: return carry - return torch.compiler.disable(scalar_loop, recursive=False) + return scalar_loop From 46e3e726326d73c8ec00a0c0d55dc62f4a49d2bf Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 28 Nov 2024 16:59:45 -0800 Subject: [PATCH 23/26] Better name --- pytensor/link/pytorch/dispatch/elemwise.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 79183da08d..fca67a244c 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -35,7 +35,7 @@ def elemwise_fn(*inputs): return base_fn(*inputs) elif isinstance(scalar_op, ScalarLoop): - return elemwise_scalar_loop(base_fn, op, node, **kwargs) + return elemwise_ravel_fn(base_fn, op, node, **kwargs) else: @@ -182,7 +182,7 @@ def softmax_grad(dy, sm): return softmax_grad -def elemwise_scalar_loop(base_fn, op, node, **kwargs): +def elemwise_ravel_fn(base_fn, op, node, **kwargs): """ ScalarLoop + Elemwise is too common to not work, but https://github.com/pymc-devs/pytensor/issues/1031, From d00f9e2e06a2a24f1905521337cc63f5ff4acc56 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 28 Nov 2024 17:07:42 -0800 Subject: [PATCH 24/26] Lint --- pytensor/link/pytorch/dispatch/elemwise.py | 1 + pytensor/link/pytorch/dispatch/scalar.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index fca67a244c..065de0ae5e 100644 --- 
a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -1,4 +1,5 @@ import importlib + import torch from pytensor.link.pytorch.dispatch.basic import pytorch_funcify diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index c27204f50d..7e6c068f21 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -8,8 +8,8 @@ Cast, ScalarOp, ) -from pytensor.scalar.math import Softplus from pytensor.scalar.loop import ScalarLoop +from pytensor.scalar.math import Softplus @pytorch_funcify.register(ScalarOp) @@ -60,6 +60,7 @@ def cast(x): return cast + @pytorch_funcify.register(Softplus) def pytorch_funcify_Softplus(op, node, **kwargs): return torch.nn.Softplus() From 5bd100e71a43cc0eab2f7fc9e0d6b8d75ba3b12d Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Thu, 28 Nov 2024 17:13:45 -0800 Subject: [PATCH 25/26] Better docstring --- pytensor/link/pytorch/dispatch/elemwise.py | 7 +++---- tests/link/pytorch/test_basic.py | 1 + 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index 065de0ae5e..ac1a352ea4 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -185,10 +185,9 @@ def softmax_grad(dy, sm): def elemwise_ravel_fn(base_fn, op, node, **kwargs): """ - ScalarLoop + Elemwise is too common - to not work, but https://github.com/pymc-devs/pytensor/issues/1031, - vmap won't allow it. Instead, we can ravel all the inputs, broadcasted - according to torch + Dispatch methods using `.item()` (ScalarLoop + Elemwise) is common, but vmap + in torch has a limitation: https://github.com/pymc-devs/pytensor/issues/1031, + Instead, we can ravel all the inputs, broadcasted according to torch """ n_outputs = len(node.outputs) diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 93cd179622..71d1243212 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -393,6 +393,7 @@ def test_pytorch_softplus(): f = FunctionGraph([x], [out]) compare_pytorch_and_py(f, [np.random.rand(3)]) + def test_ScalarLoop(): n_steps = int64("n_steps") x0 = float64("x0") From 521ad67597d44929dc199ff694078893389d7742 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 8 Dec 2024 08:46:30 -0800 Subject: [PATCH 26/26] Pr comments --- pytensor/link/pytorch/dispatch/elemwise.py | 2 +- pytensor/link/pytorch/dispatch/scalar.py | 1 - pytensor/link/pytorch/linker.py | 27 ++++------------------ 3 files changed, 6 insertions(+), 24 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/elemwise.py b/pytensor/link/pytorch/dispatch/elemwise.py index ac1a352ea4..c22945d914 100644 --- a/pytensor/link/pytorch/dispatch/elemwise.py +++ b/pytensor/link/pytorch/dispatch/elemwise.py @@ -198,7 +198,7 @@ def elemwise_fn(*inputs): out_shape = bcasted_inputs[0].size() out_size = out_shape.numel() - raveled_outputs = [torch.zeros(out_size) for out in node.outputs] + raveled_outputs = [torch.empty(out_size) for out in node.outputs] for i in range(out_size): core_outs = base_fn(*(inp[i] for inp in raveled_inputs)) diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 7e6c068f21..65170b1f53 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -1,7 +1,6 @@ import importlib import torch -import torch.compiler from 
pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 61a39fbea9..d47aa43dda 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -9,19 +9,6 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.gen_functors = [] - def input_filter(self, inp): - from pytensor.link.pytorch.dispatch import pytorch_typify - - return pytorch_typify(inp) - - def output_filter(self, var, out): - from torch import is_tensor - - if is_tensor(out): - return out.cpu() - else: - return out - def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): from pytensor.link.pytorch.dispatch import pytorch_funcify @@ -67,21 +54,22 @@ def __init__(self, fn, gen_functors): self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() - def __call__(self, *args, **kwargs): + def __call__(self, *inputs, **kwargs): import pytensor.link.utils # set attrs for n, fn in self.gen_functors: setattr(pytensor.link.utils, n[1:], fn) - res = self.fn(*args, **kwargs) + # Torch does not accept numpy inputs and may return GPU objects + outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs) # unset attrs for n, _ in self.gen_functors: if getattr(pytensor.link.utils, n[1:], False): delattr(pytensor.link.utils, n[1:]) - return res + return tuple(out.cpu().numpy() for out in outs) def __del__(self): del self.gen_functors @@ -89,12 +77,7 @@ def __del__(self): inner_fn = wrapper(fn, self.gen_functors) self.gen_functors = [] - # Torch does not accept numpy inputs and may return GPU objects - def create_outputs(*inputs, inner_fn=inner_fn): - outs = inner_fn(*(pytorch_typify(inp) for inp in inputs)) - return tuple(out.cpu().numpy() for out in outs) - - return create_outputs + return inner_fn def create_thunk_inputs(self, storage_map): thunk_inputs = []
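
Taken together, the series teaches the PyTorch backend to execute ScalarLoop, both standalone and under Elemwise. For reference, the behavior it enables can be exercised end to end as below: a minimal sketch based on the test added in patch 01. The "PYTORCH" mode string is assumed here to select the torch linker; the tests above use a local pytorch_mode fixture instead.

    import numpy as np

    from pytensor import function
    from pytensor.scalar import float64, int64
    from pytensor.scalar.loop import ScalarLoop

    n_steps = int64("n_steps")
    x0 = float64("x0")
    const = float64("const")

    # one update step per iteration: carry <- carry + const
    op = ScalarLoop(init=[x0], constant=[const], update=[x0 + const])
    out = op(n_steps, x0, const)

    fn = function([n_steps, x0, const], out, mode="PYTORCH")
    np.testing.assert_allclose(fn(5, 0, 1), 5)    # 0 + 5 * 1
    np.testing.assert_allclose(fn(4, 3, -1), -1)  # 3 + 4 * (-1)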
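
The Elemwise strategy the series settles on in patch 16 (broadcast all inputs, ravel them, run the scalar function element by element, then reshape each output back to the broadcast shape) can be shown outside PyTensor with plain torch. The following is a standalone sketch rather than the dispatcher itself; ravel_elemwise and the two-output lambda are illustrative names, and like the patch it leaves output dtypes at torch's default.

    import torch

    def ravel_elemwise(scalar_fn, n_outputs):
        def elemwise_fn(*inputs):
            # broadcast, then flatten so the scalar function sees 0-d tensors
            bcast = torch.broadcast_tensors(*inputs)
            flat = [t.ravel() for t in bcast]
            out_shape = bcast[0].shape
            n = bcast[0].numel()
            outs = [torch.empty(n) for _ in range(n_outputs)]
            for i in range(n):
                res = scalar_fn(*(t[i] for t in flat))
                if n_outputs == 1:
                    outs[0][i] = res
                else:
                    for o in range(n_outputs):
                        outs[o][i] = res[o]
            # restore the broadcast shape on every output
            outs = [o.view(out_shape) for o in outs]
            return outs[0] if n_outputs == 1 else tuple(outs)

        return elemwise_fn

    f = ravel_elemwise(lambda a, b: (a + b, a * b), n_outputs=2)
    s, p = f(torch.arange(3.0).reshape(3, 1), torch.arange(4.0))
    assert s.shape == (3, 4) and p.shape == (3, 4)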
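
Patch 26 also changes where conversion happens in the linker: instead of the input_filter/output_filter hooks, the torch.compile wrapper itself converts numpy inputs to tensors on the way in and moves results back to CPU numpy on the way out. A reduced sketch of that pattern, with torch.as_tensor standing in for pytensor's pytorch_typify and make_thunk as an illustrative name:

    import numpy as np
    import torch

    def make_thunk(core_fn):
        compiled = torch.compile(core_fn)

        def thunk(*inputs):
            # numpy -> tensor on entry, tensor -> numpy on exit
            tensors = (torch.as_tensor(i) for i in inputs)
            outs = compiled(*tensors)
            return tuple(o.cpu().numpy() for o in outs)

        return thunk

    fn = make_thunk(lambda a, b: (a + b, a * b))
    print(fn(np.ones(3), np.arange(3.0)))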