1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -12,4 +12,5 @@
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.subtensor
import pytensor.link.pytorch.dispatch.blockwise
import pytensor.link.pytorch.dispatch.random
# isort: on
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -37,6 +37,7 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):


@pytorch_typify.register(slice)
@pytorch_typify.register(dict)
@pytorch_typify.register(NoneType)
@pytorch_typify.register(np.number)
def pytorch_typify_no_conversion_needed(data, **kwargs):
75 changes: 75 additions & 0 deletions pytensor/link/pytorch/dispatch/random.py
@@ -0,0 +1,75 @@
from functools import singledispatch

import numpy.random
import torch

import pytensor.tensor.random.basic as ptr
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify


@pytorch_typify.register(numpy.random.Generator)
def pytorch_typify_Generator(rng, **kwargs):
# XXX: Check if there is a better way.
# Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
seed = torch.from_numpy(rng.integers([2**32]))
Member
You have to copy the rng before calling rng.integers; we don't want to modify the original one.
return torch.manual_seed(seed)
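
A minimal sketch of the fix suggested in the comment above, assuming the intent is to seed torch's global generator without advancing the caller's numpy Generator (the helper name is illustrative, not part of the PR):

import copy

import numpy as np
import torch


def pytorch_typify_generator_sketch(rng: np.random.Generator, **kwargs):
    # Draw the seed from a deep copy so the original rng state is left untouched.
    seed = int(copy.deepcopy(rng).integers(2**32))
    return torch.manual_seed(seed)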


@pytorch_typify.register(torch._C.Generator)
def pytorch_typify_pass_generator(rng, **kwargs):
return rng


@pytorch_funcify.register(ptr.RandomVariable)
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
rv = node.outputs[1]
out_dtype = rv.type.dtype
shape = rv.type.shape
Member
shape is not guaranteed to be static. Use the size argument passed at runtime? Or add an if/else if this was an optimization
rv_sample = pytorch_sample_fn(op, node=node)

def sample_fn(rng, size, *args):
new_rng = torch.Generator(device="cpu")
new_rng.set_state(rng.get_state().clone())
return rv_sample(new_rng, shape, out_dtype, *args)

return sample_fn
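
A hedged sketch of one way to act on the static-shape comment above, reusing the names from the enclosing torch_funcify_RandomVariable closure and assuming the runtime size argument arrives as a concrete sequence of ints (or None):

    def sample_fn(rng, size, *args):
        new_rng = torch.Generator(device="cpu")
        new_rng.set_state(rng.get_state().clone())
        # Prefer the runtime size; fall back to the static type shape only when
        # no size was given (the static shape may contain unknown entries).
        draw_shape = tuple(int(s) for s in size) if size is not None else shape
        return rv_sample(new_rng, draw_shape, out_dtype, *args)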


@singledispatch
def pytorch_sample_fn(op, node):
name = op.name
raise NotImplementedError(
f"No PyTorch implementation for the given distribution: {name}"
)


@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
def sample_fn(gen, size, dtype, p):
sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
return (gen, sample)

return sample_fn


@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
def sample_fn(gen, size, dtype, n, p):
sample = torch.binomial(
torch.broadcast_to(n, size).to(torch.float32),
torch.broadcast_to(p, size).to(torch.float32),
generator=gen,
)
return (gen, sample)

return sample_fn


@pytorch_sample_fn.register(ptr.UniformRV)
def pytorch_sample_fn_uniform(op, node):
def sample_fn(gen, size, dtype, low, high):
sample = torch.FloatTensor(size)
sample.uniform_(low.item(), high.item(), generator=gen)
return (gen, sample)

return sample_fn
13 changes: 12 additions & 1 deletion pytensor/link/pytorch/linker.py
@@ -1,3 +1,5 @@
from numpy.random import Generator, RandomState

from pytensor.link.basic import JITLinker
from pytensor.link.utils import unique_name_generator

@@ -72,7 +74,9 @@ def __call__(self, *inputs, **kwargs):
if getattr(pytensor.link.utils, n[1:], False):
delattr(pytensor.link.utils, n[1:])

return tuple(out.cpu().numpy() for out in outs)
return tuple(
out.cpu().numpy() if torch.is_tensor(out) else out for out in outs
)

def __del__(self):
del self.gen_functors
@@ -83,9 +87,16 @@ def __del__(self):
return inner_fn

def create_thunk_inputs(self, storage_map):
from pytensor.link.pytorch.dispatch import pytorch_typify
Member
You'll need to copy the logic with SharedVariables in JAX to emit a warning and use different variables. You can refactor the logic so it's not duplicated.

thunk_inputs = []
for n in self.fgraph.inputs:
sinput = storage_map[n]
if isinstance(sinput[0], RandomState | Generator):
new_value = pytorch_typify(
sinput[0], dtype=getattr(sinput[0], "dtype", None)
Contributor
Why is this needed?
)
sinput[0] = new_value
thunk_inputs.append(sinput)

return thunk_inputs
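
A rough sketch of the refactor suggested in the linker comments above, mirroring what the JAX linker does for RandomType SharedVariables: warn that the original generator will not be updated in place and hand the backend a typified copy. The warning text and placement are assumptions, not copied from the JAX code; RandomState, Generator, and pytorch_typify come from the imports already shown in this diff.

import copy
import warnings


def create_thunk_inputs(self, storage_map):
    from pytensor.link.pytorch.dispatch import pytorch_typify

    thunk_inputs = []
    for n in self.fgraph.inputs:
        sinput = storage_map[n]
        if isinstance(sinput[0], RandomState | Generator):
            warnings.warn(
                f"The RandomType SharedVariable {n} will not be used "
                "in the compiled graph; a copy of its state is used instead.",
                UserWarning,
            )
            # Copy before typifying so the user's generator is never mutated.
            sinput[0] = pytorch_typify(copy.deepcopy(sinput[0]))
        thunk_inputs.append(sinput)

    return thunk_inputs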
78 changes: 78 additions & 0 deletions tests/link/pytorch/test_random.py
@@ -0,0 +1,78 @@
import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import shared
from pytensor.link.pytorch.dispatch.basic import pytorch_typify
from tests.link.pytorch.test_basic import pytorch_mode


torch = pytest.importorskip("torch")


@pytest.mark.parametrize("update", [(True), (False)])
def test_random_updates(update):
original = np.random.default_rng(seed=123)
original_torch = pytorch_typify(original)
rng = shared(original, name="rng", borrow=False)
rv = pt.random.bernoulli(0.5, name="y", rng=rng)
next_rng, x = rv.owner.outputs
x.dprint()
f = function([], x, updates={rng: next_rng} if update else None, mode="PYTORCH")
draws = np.stack([f() for _ in range(5)])
# assert we are getting different values
if update:
assert draws.sum() < 5 and draws.sum() >= 1
# assert we have a new rng
rng_value = rng.get_value(borrow=True) # we can't copy torch generator
assert torch.eq(rng_value.get_state(), original_torch.get_state())
else:
pass


@pytest.mark.parametrize(
"size,p",
[
((1000,), 0.5),
(None, 0.5),
((1000, 4), 0.5),
((10, 2), np.array([0.5, 0.3])),
((1000, 10, 2), np.array([0.5, 0.3])),
],
)
def test_random_bernoulli(size, p):
rng = shared(np.random.default_rng(123))

g = pt.random.bernoulli(p, size=size, rng=rng)
g_fn = function([], g, mode=pytorch_mode)
samples = g_fn()
samples_mean = samples.mean(axis=0) if samples.shape else samples
np.testing.assert_allclose(samples_mean, 0.5, 1)


@pytest.mark.parametrize(
"size,n,p,update",
[
((1000,), 10, 0.5, False),
((1000, 4), 10, 0.5, False),
((1000, 2), np.array([10, 40]), np.array([0.5, 0.3]), True),
],
)
def test_binomial(size, n, p, update):
rng = shared(np.random.default_rng(123))
Member
We need tests that confirm the original rng was not affected
rv = pt.random.binomial(n, p, size=size, rng=rng)
next_rng, *_ = rv.owner.inputs
g_fn = function(
[], rv, mode=pytorch_mode, updates={rng: next_rng} if update else None
)
samples = g_fn()
Member
You should call twice. In this case, because you did not set updates, you should get the same draws back. See https://pytensor.readthedocs.io/en/latest/tutorial/prng.html for details.

You should also test with updates separately.

Contributor
I updated this to include a test without the update, but I'm not getting the same draws. I'll read through the article and see if I can see why.
if not update:
np.testing.assert_allclose(samples, g_fn(), rtol=0.1)
np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
np.testing.assert_allclose(
samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
)
else:
second_samples = g_fn()
np.testing.assert_array_equal(second_samples, samples)
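
A hedged sketch of the extra coverage the reviewers ask for above: check that the original numpy Generator is left untouched, and that without updates two calls replay the same draws. The assertions encode the intended semantics from the PRNG tutorial, not behaviour verified against this PR; the test name is illustrative.

def test_rng_not_mutated_and_no_update_replays():
    original = np.random.default_rng(123)
    state_before = original.bit_generator.state
    rng = shared(original, name="rng", borrow=False)

    g = pt.random.bernoulli(0.5, size=(3,), rng=rng)
    g_fn = function([], g, mode=pytorch_mode)  # no updates

    first = g_fn()
    second = g_fn()

    # The user's generator is not advanced by compilation or evaluation.
    assert original.bit_generator.state == state_before
    # Without an rng update the compiled function replays identical draws.
    np.testing.assert_array_equal(first, second)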