Implementation of random variables with PyTorch backend #1075
base: main
Changes from all commits
69d62c7
fcd643a
7eb77df
f93a20e
1f863f5
ac35367
288d2c4
4b4f8d0
@@ -0,0 +1,75 @@
from functools import singledispatch

import numpy.random
import torch

import pytensor.tensor.random.basic as ptr
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify


@pytorch_typify.register(numpy.random.Generator)
def pytorch_typify_Generator(rng, **kwargs):
    # XXX: Check if there is a better way.
    # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
    seed = torch.from_numpy(rng.integers([2**32]))

Review comment: You have to copy the rng before calling … [comment truncated] (see the sketch just after this function).

    return torch.manual_seed(seed)

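A minimal sketch of what the reviewer's suggestion could look like, assuming the point is to draw the seed from a copy of the NumPy Generator so the caller's rng is not advanced as a side effect. The helper name and the use of copy.deepcopy are illustrative, not code from this PR:

    import copy

    import numpy as np
    import torch


    def typify_generator_without_side_effects(rng: np.random.Generator) -> torch.Generator:
        # Work on a copy so the caller's Generator keeps its state.
        rng = copy.deepcopy(rng)
        seed = int(rng.integers(2**32))
        # Seed a fresh torch Generator instead of the global one.
        gen = torch.Generator(device="cpu")
        gen.manual_seed(seed)
        return gen
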

@pytorch_typify.register(torch._C.Generator)
def pytorch_typify_pass_generator(rng, **kwargs):
    return rng


@pytorch_funcify.register(ptr.RandomVariable)
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
    rv = node.outputs[1]
    out_dtype = rv.type.dtype
    shape = rv.type.shape

Review comment: shape is not guaranteed to be static. Use the … [comment truncated] (see the sketch just after this function).

    rv_sample = pytorch_sample_fn(op, node=node)

    def sample_fn(rng, size, *args):
        new_rng = torch.Generator(device="cpu")
        new_rng.set_state(rng.get_state().clone())
        return rv_sample(new_rng, shape, out_dtype, *args)

    return sample_fn

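The truncated review comment above is about rv.type.shape possibly containing None for dimensions that are only known at run time. A hedged sketch of one way this could be handled, falling back to the runtime size argument; this variant reuses the names defined in this file and is an illustration, not the PR's actual fix:

    def torch_funcify_random_variable_dynamic_shape(op, node, **kwargs):
        # Variant of torch_funcify_RandomVariable above: prefer the runtime
        # `size` whenever the static type shape has unknown (None) dimensions.
        rv = node.outputs[1]
        out_dtype = rv.type.dtype
        static_shape = rv.type.shape

        rv_sample = pytorch_sample_fn(op, node=node)

        def sample_fn(rng, size, *args):
            shape = size if any(dim is None for dim in static_shape) else static_shape
            new_rng = torch.Generator(device="cpu")
            new_rng.set_state(rng.get_state().clone())
            return rv_sample(new_rng, shape, out_dtype, *args)

        return sample_fn
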

@singledispatch
def pytorch_sample_fn(op, node):
    name = op.name
    raise NotImplementedError(
        f"No PyTorch implementation for the given distribution: {name}"
    )


@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    def sample_fn(gen, size, dtype, p):
        sample = torch.bernoulli(torch.broadcast_to(p, size), generator=gen)
        return (gen, sample)

    return sample_fn

(A review conversation on the torch.bernoulli line was marked as resolved by Ch0ronomato.)

@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
    def sample_fn(gen, size, dtype, n, p):
        sample = torch.binomial(
            torch.broadcast_to(n, size).to(torch.float32),
            torch.broadcast_to(p, size).to(torch.float32),
            generator=gen,
        )
        return (gen, sample)

    return sample_fn

@pytorch_sample_fn.register(ptr.UniformRV)
def pytorch_sample_fn_uniform(op, node):
    def sample_fn(gen, size, dtype, low, high):
        sample = torch.FloatTensor(size)
        sample.uniform_(low.item(), high.item(), generator=gen)
        return (gen, sample)

    return sample_fn
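
The pytorch_sample_fn singledispatch registry above is the extension point for further distributions. As an illustration only (not part of this diff), a normal RV could plausibly be registered the same way, assuming ptr.NormalRV and the torch.normal(mean, std, generator=...) overload behave as in the existing registrations:

    @pytorch_sample_fn.register(ptr.NormalRV)
    def pytorch_sample_fn_normal(op, node):
        def sample_fn(gen, size, dtype, loc, scale):
            # loc/scale arrive from the funcified graph; broadcast them to the
            # requested size and sample with the provided generator.
            sample = torch.normal(
                torch.broadcast_to(torch.as_tensor(loc, dtype=torch.float32), size),
                torch.broadcast_to(torch.as_tensor(scale, dtype=torch.float32), size),
                generator=gen,
            )
            return (gen, sample)

        return sample_fn
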
@@ -1,3 +1,5 @@
+from numpy.random import Generator, RandomState
+
 from pytensor.link.basic import JITLinker
 from pytensor.link.utils import unique_name_generator
@@ -72,7 +74,9 @@ def __call__(self, *inputs, **kwargs):
            if getattr(pytensor.link.utils, n[1:], False):
                delattr(pytensor.link.utils, n[1:])

-        return tuple(out.cpu().numpy() for out in outs)
+        return tuple(
+            out.cpu().numpy() if torch.is_tensor(out) else out for out in outs
+        )

     def __del__(self):
         del self.gen_functors
@@ -83,9 +87,16 @@ def __del__(self):
         return inner_fn

     def create_thunk_inputs(self, storage_map):
+        from pytensor.link.pytorch.dispatch import pytorch_typify

Review comment: You'll need to copy the logic with SharedVariables in JAX to emit a warning and use different variables. You can refactor the logic so it's not duplicated. (A hedged sketch of that idea appears after this file.)

+
         thunk_inputs = []
         for n in self.fgraph.inputs:
             sinput = storage_map[n]
+            if isinstance(sinput[0], RandomState | Generator):
+                new_value = pytorch_typify(
+                    sinput[0], dtype=getattr(sinput[0], "dtype", None)

Review comment (on the dtype argument): Why is this needed?

+                )
+                sinput[0] = new_value
             thunk_inputs.append(sinput)

         return thunk_inputs
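
For reference, the reviewer's pointer to the JAX linker refers to its handling of shared RNG variables: warn the user and substitute copies so the original shared Generator is not consumed by the compiled graph. The sketch below is a hedged adaptation of that idea to the PyTorch linker; the helper name and the wording of the warning are assumptions, not code from this PR or a verbatim copy of the JAX linker:

    # Hedged sketch only: mirrors the JAX linker's shared-RNG handling.
    import warnings

    from pytensor.compile.sharedvalue import SharedVariable, shared
    from pytensor.tensor.random.type import RandomType


    def replace_shared_rng_inputs(fgraph):
        """Warn about shared RNG inputs and swap in copies so the originals
        are not mutated when the compiled function updates its generator."""
        shared_rng_inputs = [
            inp
            for inp in fgraph.inputs
            if isinstance(inp, SharedVariable) and isinstance(inp.type, RandomType)
        ]
        if not shared_rng_inputs:
            return

        warnings.warn(
            f"The RandomType SharedVariables {shared_rng_inputs} will not be used "
            f"in the compiled PyTorch graph. Instead a copy will be used.",
            UserWarning,
        )
        new_shared_rng_inputs = [
            shared(inp.get_value(borrow=False)) for inp in shared_rng_inputs
        ]
        # Rewire the graph to the copies so the user's shared rng is untouched.
        fgraph.replace_all(
            zip(shared_rng_inputs, new_shared_rng_inputs),
            import_missing=True,
            reason="PytorchLinker.fgraph_convert",
        )
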
@@ -0,0 +1,78 @@
import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor.compile.function import function
from pytensor.compile.sharedvalue import shared
from pytensor.link.pytorch.dispatch.basic import pytorch_typify
from tests.link.pytorch.test_basic import pytorch_mode


torch = pytest.importorskip("torch")


@pytest.mark.parametrize("update", [(True), (False)])
def test_random_updates(update):
    original = np.random.default_rng(seed=123)
    original_torch = pytorch_typify(original)
    rng = shared(original, name="rng", borrow=False)
    rv = pt.random.bernoulli(0.5, name="y", rng=rng)
    next_rng, x = rv.owner.outputs
    x.dprint()
    f = function([], x, updates={rng: next_rng} if update else None, mode="PYTORCH")
    draws = np.stack([f() for _ in range(5)])
    # assert we are getting different values
    if update:
        assert draws.sum() < 5 and draws.sum() >= 1
        # assert we have a new rng
        rng_value = rng.get_value(borrow=True)  # we can't copy torch generator
        assert torch.eq(rng_value.get_state(), original_torch.get_state())
    else:
        pass


@pytest.mark.parametrize(
    "size,p",
    [
        ((1000,), 0.5),
        (None, 0.5),
        ((1000, 4), 0.5),
        ((10, 2), np.array([0.5, 0.3])),
        ((1000, 10, 2), np.array([0.5, 0.3])),
    ],
)
def test_random_bernoulli(size, p):
    rng = shared(np.random.default_rng(123))

    g = pt.random.bernoulli(p, size=size, rng=rng)
    g_fn = function([], g, mode=pytorch_mode)
    samples = g_fn()
    samples_mean = samples.mean(axis=0) if samples.shape else samples
    np.testing.assert_allclose(samples_mean, 0.5, 1)


@pytest.mark.parametrize(
    "size,n,p,update",
    [
        ((1000,), 10, 0.5, False),
        ((1000, 4), 10, 0.5, False),
        ((1000, 2), np.array([10, 40]), np.array([0.5, 0.3]), True),
    ],
)
def test_binomial(size, n, p, update):
    rng = shared(np.random.default_rng(123))

Review comment: We need tests that confirm the original rng was not affected. (A hedged sketch of such a test appears after this file.)

    rv = pt.random.binomial(n, p, size=size, rng=rng)
    next_rng, *_ = rv.owner.inputs
    g_fn = function(
        [], rv, mode=pytorch_mode, updates={rng: next_rng} if update else None
    )
    samples = g_fn()

Review comment: You should call twice. In this case, because you did not set updates you should get the same draws back. See https://pytensor.readthedocs.io/en/latest/tutorial/prng.html for details. You should also test with updates separately.

Reply: I updated this to include a test without the update, but I'm not getting the same draws. I'll read through the article and see if I can see why.

    if not update:
        np.testing.assert_allclose(samples, g_fn(), rtol=0.1)
        np.testing.assert_allclose(samples.mean(axis=0), n * p, rtol=0.1)
        np.testing.assert_allclose(
            samples.std(axis=0), np.sqrt(n * p * (1 - p)), rtol=0.2
        )
    else:
        second_samples = g_fn()
        np.testing.assert_array_equal(second_samples, samples)
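
A hedged sketch of the test the reviewer asks for, confirming that the user-supplied NumPy Generator is not mutated by compilation or sampling. It reuses the imports of this test module, and the test name is illustrative, not part of the PR:

    def test_original_rng_not_mutated():
        original = np.random.default_rng(123)
        state_before = original.bit_generator.state  # snapshot of the PCG64 state

        rng = shared(original, name="rng", borrow=False)
        rv = pt.random.bernoulli(0.5, rng=rng)
        f = function([], rv, mode=pytorch_mode)
        f()

        # Compiling and sampling must not advance the user's Generator in place.
        assert original.bit_generator.state == state_before
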