diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py
index 4caabf3e03..8c594fa4f6 100644
--- a/pytensor/link/pytorch/dispatch/__init__.py
+++ b/pytensor/link/pytorch/dispatch/__init__.py
@@ -12,4 +12,5 @@
 import pytensor.link.pytorch.dispatch.sort
 import pytensor.link.pytorch.dispatch.subtensor
 import pytensor.link.pytorch.dispatch.blockwise
+import pytensor.link.pytorch.dispatch.random
 # isort: on
diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py
index ef4bf10637..cda2a03c56 100644
--- a/pytensor/link/pytorch/dispatch/basic.py
+++ b/pytensor/link/pytorch/dispatch/basic.py
@@ -37,6 +37,7 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):
 
 
 @pytorch_typify.register(slice)
+@pytorch_typify.register(dict)
 @pytorch_typify.register(NoneType)
 @pytorch_typify.register(np.number)
 def pytorch_typify_no_conversion_needed(data, **kwargs):
diff --git a/pytensor/link/pytorch/dispatch/random.py b/pytensor/link/pytorch/dispatch/random.py
new file mode 100644
index 0000000000..fc62922a27
--- /dev/null
+++ b/pytensor/link/pytorch/dispatch/random.py
@@ -0,0 +1,75 @@
+from functools import singledispatch
+
+import numpy.random
+import torch
+
+import pytensor.tensor.random.basic as ptr
+from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
+
+
+@pytorch_typify.register(numpy.random.Generator)
+def pytorch_typify_Generator(rng, **kwargs):
+    # XXX: Check if there is a better way.
+    # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
+    seed = torch.from_numpy(rng.integers([2**32]))
+    return torch.manual_seed(seed)
+
+
+@pytorch_typify.register(torch._C.Generator)
+def pytorch_typify_pass_generator(rng, **kwargs):
+    return rng
+
+
+@pytorch_funcify.register(ptr.RandomVariable)
+def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
+    rv = node.outputs[1]
+    out_dtype = rv.type.dtype
+    shape = rv.type.shape
+    rv_sample = pytorch_sample_fn(op, node=node)
+
+    def sample_fn(rng, size, *args):
+        new_rng = torch.Generator(device="cpu")
+        new_rng.set_state(rng.get_state().clone())
+        return rv_sample(new_rng, shape, out_dtype, *args)
+
+    return sample_fn
+
+
+@singledispatch
+def pytorch_sample_fn(op, node):
+    name = op.name
+    raise NotImplementedError(
+        f"No PyTorch implementation for the given distribution: {name}"
+    )
+
+
+@pytorch_sample_fn.register(ptr.BernoulliRV)
+def pytorch_sample_fn_bernoulli(op, node):
+    def sample_fn(gen, size, dtype, p):
+        sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
+        return (gen, sample)
+
+    return sample_fn
+
+
+@pytorch_sample_fn.register(ptr.BinomialRV)
+def pytorch_sample_fn_binomial(op, node):
+    def sample_fn(gen, size, dtype, n, p):
+        sample = torch.binomial(
+            torch.broadcast_to(n, size).to(torch.float32),
+            torch.broadcast_to(p, size).to(torch.float32),
+            generator=gen,
+        )
+        return (gen, sample)
+
+    return sample_fn
+
+
+@pytorch_sample_fn.register(ptr.UniformRV)
+def pytorch_sample_fn_uniform(op, node):
+    def sample_fn(gen, size, dtype, low, high):
+        sample = torch.FloatTensor(size)
+        sample.uniform_(low.item(), high.item(), generator=gen)
+        return (gen, sample)
+
+    return sample_fn
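Editorial note (not part of the patch): numpy's default Generator uses PCG64 while torch's CPU generator is Mersenne-Twister, so there is no shared state to copy. The pytorch_typify_Generator shim above therefore just seeds a fresh torch generator from one draw of the numpy generator. A rough standalone equivalent, assuming only numpy and torch:

    import numpy as np
    import torch

    rng = np.random.default_rng(123)
    seed = int(rng.integers(2**32))  # one draw in [0, 2**32) from the numpy generator
    gen = torch.manual_seed(seed)    # seeds and returns torch's default CPU generator
    assert isinstance(gen, torch.Generator)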
diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py
index b8475e3157..2ed721e708 100644
--- a/pytensor/link/pytorch/linker.py
+++ b/pytensor/link/pytorch/linker.py
@@ -1,3 +1,5 @@
+from numpy.random import Generator, RandomState
+
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator
 
@@ -72,7 +74,9 @@ def __call__(self, *inputs, **kwargs):
                     if getattr(pytensor.link.utils, n[1:], False):
                         delattr(pytensor.link.utils, n[1:])
 
-                return tuple(out.cpu().numpy() for out in outs)
+                return tuple(
+                    out.cpu().numpy() if torch.is_tensor(out) else out for out in outs
+                )
 
             def __del__(self):
                 del self.gen_functors
@@ -83,9 +87,16 @@ def __del__(self):
         return inner_fn
 
     def create_thunk_inputs(self, storage_map):
+        from pytensor.link.pytorch.dispatch import pytorch_typify
+
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = pytorch_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)
+                )
+                sinput[0] = new_value
             thunk_inputs.append(sinput)
 
         return thunk_inputs
diff --git a/tests/link/pytorch/test_random.py b/tests/link/pytorch/test_random.py
new file mode 100644
index 0000000000..22c9aa28c3
--- /dev/null
+++ b/tests/link/pytorch/test_random.py
@@ -0,0 +1,78 @@
+import numpy as np
+import pytest
+
+import pytensor.tensor as pt
+from pytensor.compile.function import function
+from pytensor.compile.sharedvalue import shared
+from pytensor.link.pytorch.dispatch.basic import pytorch_typify
+from tests.link.pytorch.test_basic import pytorch_mode
+
+
+torch = pytest.importorskip("torch")
+
+
+@pytest.mark.parametrize("update", [True, False])
+def test_random_updates(update):
+    original = np.random.default_rng(seed=123)
+    original_torch = pytorch_typify(original)
+    rng = shared(original, name="rng", borrow=False)
+    rv = pt.random.bernoulli(0.5, name="y", rng=rng)
+    next_rng, x = rv.owner.outputs
+    f = function([], x, updates={rng: next_rng} if update else None, mode="PYTORCH")
+    draws = np.stack([f() for _ in range(5)])
+    if update:
+        # assert we are getting different values
+        assert draws.sum() < 5 and draws.sum() >= 1
+        # assert we have a new rng
+        rng_value = rng.get_value(borrow=True)  # we can't copy torch generator
+        assert not torch.equal(rng_value.get_state(), original_torch.get_state())
+    else:
+        # without updates the rng state never advances, so every draw repeats
+        assert np.all(draws == draws[0])
+
+
+@pytest.mark.parametrize(
+    "size,p",
+    [
+        ((1000,), 0.5),
+        (None, 0.5),
+        ((1000, 4), 0.5),
+        ((10, 2), np.array([0.5, 0.3])),
+        ((1000, 10, 2), np.array([0.5, 0.3])),
+    ],
+)
+def test_random_bernoulli(size, p):
+    rng = shared(np.random.default_rng(123))
+
+    g = pt.random.bernoulli(p, size=size, rng=rng)
+    g_fn = function([], g, mode=pytorch_mode)
+    samples = g_fn()
+    samples_mean = samples.mean(axis=0) if samples.shape else samples
+    np.testing.assert_allclose(samples_mean, 0.5, rtol=1)
+
+
+@pytest.mark.parametrize(
+    "size,n,p,update",
+    [
+        ((1000,), 10, 0.5, False),
+        ((1000, 4), 10, 0.5, False),
+        ((1000, 2), np.array([10, 40]), np.array([0.5, 0.3]), True),
+    ],
+)
+def test_binomial(size, n, p, update):
+    rng = shared(np.random.default_rng(123))
+    rv = pt.random.binomial(n, p, size=size, rng=rng)
+    next_rng, *_ = rv.owner.outputs
+    g_fn = function(
+        [], rv, mode=pytorch_mode, updates={rng: next_rng} if update else None
+    )
+    samples = g_fn()
+    if not update:
+        np.testing.assert_allclose(samples, g_fn(), rtol=0.1)
+        np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
+        np.testing.assert_allclose(
+            samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
+        )
+    else:
+        second_samples = g_fn()
+        assert not np.array_equal(second_samples, samples)  # rng update took effect
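Illustrative usage (editorial, not part of the patch): a minimal end-to-end sketch of what the patch enables, mirroring tests/link/pytorch/test_random.py. It assumes torch is installed and the "PYTORCH" compile mode is available:

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.compile.function import function
    from pytensor.compile.sharedvalue import shared

    # The shared numpy Generator is converted to a torch.Generator by the new
    # pytorch_typify dispatch when the thunk inputs are created.
    rng = shared(np.random.default_rng(123), name="rng")
    next_rng, draws = pt.random.uniform(0.0, 1.0, size=(3,), rng=rng).owner.outputs

    # Passing next_rng as an update advances the generator between calls.
    fn = function([], draws, updates={rng: next_rng}, mode="PYTORCH")
    print(fn())  # three uniform draws sampled by torch
    print(fn())  # different draws, because the rng update took effect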