|
30 | 30 | from aesara.tensor.basic import as_tensor_variable
|
31 | 31 | from aesara.tensor.elemwise import Elemwise
|
32 | 32 | from aesara.tensor.random.op import RandomVariable
|
33 |
| -from aesara.tensor.random.var import RandomStateSharedVariable |
34 | 33 | from aesara.tensor.var import TensorVariable
|
35 | 34 | from typing_extensions import TypeAlias
|
36 | 35 |
|
@@ -358,23 +357,6 @@ def dist(
|
358 | 357 | replicate_shape = cast(StrongShape, shape[:-1])
|
359 | 358 | rv_out = change_rv_size(rv_var=rv_out, new_size=replicate_shape, expand=True)
|
360 | 359 |
|
361 |
| - rng = kwargs.pop("rng", None) |
362 |
| - if ( |
363 |
| - rv_out.owner |
364 |
| - and isinstance(rv_out.owner.op, RandomVariable) |
365 |
| - and isinstance(rng, RandomStateSharedVariable) |
366 |
| - and not getattr(rng, "default_update", None) |
367 |
| - ): |
368 |
| - # This tells `aesara.function` that the shared RNG variable |
369 |
| - # is mutable, which--in turn--tells the `FunctionGraph` |
370 |
| - # `Supervisor` feature to allow in-place updates on the variable. |
371 |
| - # Without it, the `RandomVariable`s could not be optimized to allow |
372 |
| - # in-place RNG updates, forcing all sample results from compiled |
373 |
| - # functions to be the same on repeated evaluations. |
374 |
| - new_rng = rv_out.owner.outputs[0] |
375 |
| - rv_out.update = (rng, new_rng) |
376 |
| - rng.default_update = new_rng |
377 |
| - |
378 | 360 | rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
|
379 | 361 | rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
|
380 | 362 | rv_out.random = _make_nice_attr_error("rv.random()", "rv.eval()")
|
@@ -589,27 +571,6 @@ def dist(
|
589 | 571 | replicate_shape = cast(StrongShape, shape[:-1])
|
590 | 572 | graph = cls.change_size(rv=graph, new_size=replicate_shape, expand=True)
|
591 | 573 |
|
592 |
| - rngs = kwargs.pop("rngs", None) |
593 |
| - if rngs is not None: |
594 |
| - graph_rvs = cls.graph_rvs(graph) |
595 |
| - assert len(rngs) == len(graph_rvs) |
596 |
| - for rng, rv_out in zip(rngs, graph_rvs): |
597 |
| - if ( |
598 |
| - rv_out.owner |
599 |
| - and isinstance(rv_out.owner.op, RandomVariable) |
600 |
| - and isinstance(rng, RandomStateSharedVariable) |
601 |
| - and not getattr(rng, "default_update", None) |
602 |
| - ): |
603 |
| - # This tells `aesara.function` that the shared RNG variable |
604 |
| - # is mutable, which--in turn--tells the `FunctionGraph` |
605 |
| - # `Supervisor` feature to allow in-place updates on the variable. |
606 |
| - # Without it, the `RandomVariable`s could not be optimized to allow |
607 |
| - # in-place RNG updates, forcing all sample results from compiled |
608 |
| - # functions to be the same on repeated evaluations. |
609 |
| - new_rng = rv_out.owner.outputs[0] |
610 |
| - rv_out.update = (rng, new_rng) |
611 |
| - rng.default_update = new_rng |
612 |
| - |
613 | 574 | # TODO: Create new attr error stating that these are not available for DerivedDistribution
|
614 | 575 | # rv_out.logp = _make_nice_attr_error("rv.logp(x)", "pm.logp(rv, x)")
|
615 | 576 | # rv_out.logcdf = _make_nice_attr_error("rv.logcdf(x)", "pm.logcdf(rv, x)")
|
|
0 commit comments