from numpy.random import Generator

import pytensor.tensor.random.basic as ptr
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
@pytorch_typify.register(Generator)
def pytorch_typify_Generator(rng, **kwargs):
    """Convert a NumPy ``Generator`` into the state dict used by the PyTorch backend.

    The NumPy bit-generator state is preserved, and a torch generator —
    seeded from a fresh 32-bit draw of ``rng`` — is added under the
    ``"pytorch_gen"`` key.

    XXX: Check if there is a better way.
    NumPy uses PCG64 while Torch uses Mersenne-Twister
    (https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/CPUGeneratorImpl.cpp),
    so the two generators cannot share state directly.
    """
    # Capture the NumPy state *before* drawing the seed below, since the
    # draw advances `rng`.
    state = rng.__getstate__()
    torch_seed = torch.from_numpy(rng.integers([2**32]))
    # NOTE(review): `torch.manual_seed` (re)seeds and returns the *global*
    # default generator, so all typified Generators share it — confirm this
    # is intended.
    state["pytorch_gen"] = torch.manual_seed(torch_seed)
    return state
1718
1819
@pytorch_funcify.register(ptr.RandomVariable)
def torch_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
    """Create a PyTorch sampling function for a ``RandomVariable`` node.

    Returns a ``sample_fn(rng, size, *parameters)`` closure that forwards to
    the op-specific sampler registered on ``pytorch_sample_fn``.
    """
    rv = node.outputs[1]
    out_dtype = rv.type.dtype
    # Static output shape taken from the node's type; it is passed to the
    # sampler in place of the runtime `size` argument, which is ignored.
    shape = rv.type.shape

    # Hoist the single-dispatch lookup out of the closure so it is resolved
    # once at compile time instead of on every call of the compiled function.
    sample = pytorch_sample_fn(op, node=node)

    def sample_fn(rng, size, *parameters):
        return sample(rng, shape, out_dtype, *parameters)

    return sample_fn
4330
# The generic `pytorch_sample_fn(op, node)` dispatcher is defined earlier in
# this file; the registrations below provide op-specific samplers.
@pytorch_sample_fn.register(ptr.BernoulliRV)
def pytorch_sample_fn_bernoulli(op, node):
    """Build the sampling closure for a Bernoulli random variable."""

    def sample_fn(rng, size, dtype, p):
        # `dtype` is accepted for interface uniformity but unused here.
        generator = rng["pytorch_gen"]
        probs = torch.broadcast_to(p, size)
        draws = torch.bernoulli(probs, generator=generator)
        return generator, draws

    return sample_fn
48+
49+
@pytorch_sample_fn.register(ptr.BinomialRV)
def pytorch_sample_fn_binomial(op, node):
    """Build the sampling closure for a Binomial random variable."""

    def sample_fn(rng, size, dtype, n, p):
        # `dtype` is accepted for interface uniformity but unused here.
        generator = rng["pytorch_gen"]
        # Cast the counts to `p`'s dtype before broadcasting — presumably
        # because `torch.binomial` expects both tensors in a matching
        # floating dtype; confirm against torch docs.
        counts = torch.broadcast_to(n.to(p.dtype), size)
        probs = torch.broadcast_to(p, size)
        draws = torch.binomial(counts, probs, generator=generator)
        return generator, draws

    return sample_fn