Skip to content
Merged
35 changes: 35 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,41 @@
def __contains__(self, key):
return key in self.named_vars or self.name_for(key) in self.named_vars

def __copy__(self):
return self.copy()

Check warning on line 1575 in pymc/model/core.py

View check run for this annotation

Codecov / codecov/patch

pymc/model/core.py#L1574-L1575

Added lines #L1574 - L1575 were not covered by tests

def __deepcopy__(self, _):
return self.copy()

Check warning on line 1578 in pymc/model/core.py

View check run for this annotation

Codecov / codecov/patch

pymc/model/core.py#L1578

Added line #L1578 was not covered by tests

def copy(self):
"""
Clone the model

To access variables in the cloned model use `cloned_model["var_name"]`.

Examples
--------
.. code-block:: python

import pymc as pm
import copy

with pm.Model() as m:
p = pm.Beta("p", 1, 1)
x = pm.Bernoulli("x", p=p, shape=(3,))

clone_m = copy.copy(m)

# Access cloned variables by name
clone_x = clone_m["x"]

# z will be part of clone_m but not m
z = pm.Deterministic("z", clone_x + 1)
"""
from pymc.model.fgraph import clone_model

return clone_model(self)

def replace_rvs_by_values(
self,
graphs: Sequence[TensorVariable],
Expand Down
10 changes: 10 additions & 0 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from copy import copy, deepcopy

import pytensor
Expand Down Expand Up @@ -158,6 +160,14 @@ def fgraph_from_model(
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
)

if any(
("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars
):
warnings.warn(
"Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
UserWarning,
)

# Collect PyTensor variables
rvs_to_values = model.rvs_to_values
rvs = list(rvs_to_values.keys())
Expand Down
46 changes: 46 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
import threading
import traceback
Expand Down Expand Up @@ -1761,3 +1762,48 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
figsize=None,
dpi=300,
)


class TestModelCopy:
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_copy_model(self, copy_method) -> None:
with pm.Model() as simple_model:
pm.Normal("y")

copy_simple_model = copy_method(simple_model)

with simple_model:
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)

with copy_simple_model:
z = pm.Deterministic("z", copy_simple_model["y"] + 1)
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
samples=1, random_seed=42
)

assert (
simple_model_prior_predictive["prior"]["y"].values
== copy_simple_model_prior_predictive["prior"]["y"].values
)

assert "z" in copy_simple_model.named_vars
assert "z" not in simple_model.named_vars
assert (
copy_simple_model_prior_predictive["prior"]["z"].values
== 1 + simple_model_prior_predictive["prior"]["y"].values
)

@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
def test_guassian_process_copy_failure(self, copy_method) -> None:
with pm.Model() as gaussian_process_model:
ell = pm.Gamma("ell", alpha=2, beta=1)
cov = 2 * pm.gp.cov.ExpQuad(1, ell)
gp = pm.gp.Latent(cov_func=cov)
f = gp.prior("f", X=np.arange(10)[:, None])
pm.Normal("y", f * 2)

with pytest.warns(
UserWarning,
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
):
copy_method(gaussian_process_model)
Loading