 from pytensor.graph.basic import Apply, Variable, equal_computations
 from pytensor.graph.op import Op
 from pytensor.graph.replace import _vectorize_node
-from pytensor.misc.safe_asarray import _asarray
 from pytensor.scalar import ScalarVariable
 from pytensor.tensor.basic import (
     as_tensor_variable,
@@ -389,7 +388,6 @@ def dist_params(self, node) -> Sequence[Variable]:
 
     def perform(self, node, inputs, outputs):
         rng_var_out, smpl_out = outputs
-
         rng, size, *args = inputs
 
         # Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
@@ -400,12 +398,11 @@ def perform(self, node, inputs, outputs):
 
         if size is not None:
             size = tuple(size)
-        smpl_val = self.rng_fn(rng, *([*args, size]))
-
-        if not isinstance(smpl_val, np.ndarray) or str(smpl_val.dtype) != self.dtype:
-            smpl_val = _asarray(smpl_val, dtype=self.dtype)
 
-        smpl_out[0] = smpl_val
+        smpl_out[0] = np.asarray(
+            self.rng_fn(rng, *args, size),
+            dtype=self.dtype,
+        )
 
     def grad(self, inputs, outputs):
         return [
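
The rationale behind the new code can be sketched with plain NumPy (this is an illustration outside the diff, not PyTensor code): `np.asarray` already passes an ndarray through unchanged when it has the requested dtype, so the explicit `isinstance`/dtype guard around the removed `_asarray` call is redundant, and `rng_fn(rng, *args, size)` unpacks the arguments identically to the old `rng_fn(rng, *([*args, size]))`.

```python
import numpy as np

rng = np.random.default_rng(123)

# Case 1: the draw is already an ndarray of the requested dtype.
# np.asarray returns it without copying, so no separate dtype check is needed.
draw = rng.normal(size=(3,))                  # float64 ndarray
assert np.asarray(draw, dtype="float64") is draw

# Case 2: the draw is a bare scalar (no size given). np.asarray wraps it in a
# 0-d ndarray of the requested dtype, roughly what the removed _asarray branch handled.
scalar_draw = rng.standard_normal()           # a single float, not an ndarray
wrapped = np.asarray(scalar_draw, dtype="float32")
assert isinstance(wrapped, np.ndarray) and wrapped.dtype == np.float32

# The two argument-unpacking styles in the diff call the sampler identically.
args, size = (0.0, 1.0), (2, 2)
a = rng.normal(*([*args, size]))              # old style
b = rng.normal(*args, size)                   # new style
assert a.shape == b.shape == (2, 2)
```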