11from functools import singledispatch
22
3- import numpy as np
3+ import numpy . random
44import torch
5- from numpy .random import Generator
65
76import pytensor .tensor .random .basic as ptr
87from pytensor .link .pytorch .dispatch .basic import pytorch_funcify , pytorch_typify
98
109
11- @pytorch_typify .register (Generator )
10+ @pytorch_typify .register (numpy . random . Generator )
1211def pytorch_typify_Generator (rng , ** kwargs ):
1312 # XXX: Check if there is a better way.
1413 # Numpy uses PCG64 while Torch uses Mersenne-Twister (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp)
15- state = rng .__getstate__ ()
16- rng_copy = np .random .default_rng ()
17- rng_copy .bit_generator .state = rng .bit_generator .state
18- seed = torch .from_numpy (rng_copy .integers ([2 ** 32 ]))
19- state ["pytorch_gen" ] = torch .manual_seed (seed )
20- return state
14+ seed = torch .from_numpy (rng .integers ([2 ** 32 ]))
15+ return torch .manual_seed (seed )
16+
17+
18+ @pytorch_typify .register (torch ._C .Generator )
19+ def pytorch_typify_pass_generator (rng , ** kwargs ):
20+ return rng
2121
2222
2323@pytorch_funcify .register (ptr .RandomVariable )
@@ -28,7 +28,9 @@ def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
2828 rv_sample = pytorch_sample_fn (op , node = node )
2929
3030 def sample_fn (rng , size , * args ):
31- return rv_sample (rng , shape , out_dtype , * args )
31+ new_rng = torch .Generator (device = "cpu" )
32+ new_rng .set_state (rng .get_state ().clone ())
33+ return rv_sample (new_rng , shape , out_dtype , * args )
3234
3335 return sample_fn
3436
@@ -43,25 +45,31 @@ def pytorch_sample_fn(op, node):
4345
4446@pytorch_sample_fn .register (ptr .BernoulliRV )
4547def pytorch_sample_fn_bernoulli (op , node ):
46- def sample_fn (rng , size , dtype , p ):
47- gen = rng ["pytorch_gen" ]
48+ def sample_fn (gen , size , dtype , p ):
4849 sample = torch .bernoulli (torch .broadcast_to (p , size ), generator = gen )
49- rng ["pytorch_gen" ] = gen
50- return (rng , sample )
50+ return (gen , sample )
5151
5252 return sample_fn
5353
5454
5555@pytorch_sample_fn .register (ptr .BinomialRV )
5656def pytorch_sample_fn_binomial (op , node ):
57- def sample_fn (rng , size , dtype , n , p ):
58- gen = rng ["pytorch_gen" ]
57+ def sample_fn (gen , size , dtype , n , p ):
5958 sample = torch .binomial (
60- torch .broadcast_to (n , size ),
61- torch .broadcast_to (p , size ),
59+ torch .broadcast_to (n , size ). to ( torch . float32 ) ,
60+ torch .broadcast_to (p , size ). to ( torch . float32 ) ,
6261 generator = gen ,
6362 )
64- rng ["pytorch_gen" ] = gen
65- return (rng , sample )
63+ return (gen , sample )
64+
65+ return sample_fn
66+
67+
68+ @pytorch_sample_fn .register (ptr .UniformRV )
69+ def pytorch_sample_fn_uniform (op , node ):
70+ def sample_fn (gen , size , dtype , low , high ):
71+ sample = torch .FloatTensor (size )
72+ sample .uniform_ (low .item (), high .item (), generator = gen )
73+ return (gen , sample )
6674
6775 return sample_fn
0 commit comments