Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pytensor/tensor/random/op.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
from typing import Any, cast

import numpy as np
Expand All @@ -23,6 +22,7 @@
from pytensor.tensor.random.type import RandomGeneratorType, RandomType
from pytensor.tensor.random.utils import (
compute_batch_shape,
custom_deepcopy,
explicit_expand_dims,
normalize_size_param,
)
Expand Down Expand Up @@ -421,7 +421,7 @@ def perform(self, node, inputs, outputs):

# Draw from `rng` if `self.inplace` is `True`, and from a copy of `rng` otherwise.
if not self.inplace:
rng = deepcopy(rng)
rng = custom_deepcopy(rng)

outputs[0][0] = rng
outputs[1][0] = np.asarray(
Expand Down
9 changes: 9 additions & 0 deletions pytensor/tensor/random/utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from collections.abc import Callable, Sequence
from copy import deepcopy
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING

import numpy as np
from numpy.random import Generator

from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
Expand Down Expand Up @@ -201,6 +203,13 @@ def normalize_size_param(
return shape


def custom_deepcopy(rng):
old_bitgen = rng.bit_generator
new_bitgen = type(old_bitgen)(deepcopy(old_bitgen._seed_seq))
new_bitgen.state = old_bitgen.state
return Generator(new_bitgen)


class RandomStream:
"""Module component with similar interface to `numpy.random.Generator`.

Expand Down
40 changes: 40 additions & 0 deletions tests/tensor/random/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import timeit
from copy import deepcopy

import numpy as np
import pytest

Expand All @@ -7,6 +10,7 @@
from pytensor.tensor.random.utils import (
RandomStream,
broadcast_params,
custom_deepcopy,
supp_shape_from_ref_param_shape,
)
from pytensor.tensor.type import matrix, tensor
Expand Down Expand Up @@ -327,3 +331,39 @@ def test_supp_shape_from_ref_param_shape():
ref_param_idx=1,
)
assert res == (3, 4)


def test_custom_deepcopy_matches_deepcopy():
rng = np.random.default_rng(123)

dp = deepcopy(rng).bit_generator
fc = custom_deepcopy(rng).bit_generator

# Same state
assert dp.state == fc.state
# Same seed sequence
assert dp.seed_seq.state == fc.seed_seq.state


def test_custom_deepcopy_output_identical():
rng = np.random.default_rng(123)

rng1 = deepcopy(rng)
rng2 = custom_deepcopy(rng)

# Generate numbers from each
x1 = rng1.normal(size=10)
x2 = rng2.normal(size=10)

assert np.allclose(x1, x2)


@pytest.mark.performance
def test_custom_deepcopy_faster_than_deepcopy():
rng = np.random.default_rng()

t_dp = timeit.timeit(lambda: deepcopy(rng), number=2000)
t_fc = timeit.timeit(lambda: custom_deepcopy(rng), number=2000)

# Fast copy should be at least 20% faster
assert t_fc < t_dp * 0.8