Skip to content
Merged
34 changes: 34 additions & 0 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1571,6 +1571,40 @@ def __getitem__(self, key):
def __contains__(self, key):
return key in self.named_vars or self.name_for(key) in self.named_vars

def __copy__(self):
    """Support ``copy.copy(model)`` by delegating to :meth:`Model.copy`."""
    return self.copy()

def __deepcopy__(self, _):
    """Support ``copy.deepcopy(model)``; deep and shallow copies are identical clones.

    The memo dict argument is ignored because the clone is rebuilt from the
    model's functional graph rather than copied object-by-object.
    """
    return self.copy()

def copy(self):
    """Clone the model.

    To access variables in the cloned model use ``cloned_model["var_name"]``.

    Constants are not cloned, and a ``UserWarning`` is emitted if variables
    that look like they were created by GP objects are detected
    (see ``pymc.model.fgraph.fgraph_from_model``).

    Returns
    -------
    Model
        An independent clone of this model; changes to the clone do not
        affect the original.

    Examples
    --------
    .. code-block:: python

        import pymc as pm
        import copy

        with pm.Model() as m:
            p = pm.Beta("p", 1, 1)
            x = pm.Bernoulli("x", p=p, shape=(3,))

        clone_m = copy.copy(m)

        # Access cloned variables by name
        clone_x = clone_m["x"]

        # z will be part of clone_m but not m
        z = pm.Deterministic("z", clone_x + 1)
    """
    # Imported locally to avoid a circular import between model.core and
    # model.fgraph.
    from pymc.model.fgraph import clone_model

    return clone_model(self)

def replace_rvs_by_values(
self,
graphs: Sequence[TensorVariable],
Expand Down
10 changes: 10 additions & 0 deletions pymc/model/fgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

from copy import copy, deepcopy

import pytensor
Expand Down Expand Up @@ -158,6 +160,14 @@ def fgraph_from_model(
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
)

if any(
("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars
):
warnings.warn(
"Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
UserWarning,
)

# Collect PyTensor variables
rvs_to_values = model.rvs_to_values
rvs = list(rvs_to_values.keys())
Expand Down
61 changes: 61 additions & 0 deletions tests/model/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy
import pickle
import threading
import traceback
Expand Down Expand Up @@ -1761,3 +1762,63 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
figsize=None,
dpi=300,
)


class TestModelCopy:
    """Tests for cloning models via ``copy.copy`` / ``copy.deepcopy``."""

    @staticmethod
    def simple_model() -> pm.Model:
        # Small model with two free RVs feeding a third, shared by the
        # copy tests below.
        with pm.Model() as simple_model:
            error = pm.HalfNormal("error", 0.5)
            alpha = pm.Normal("alpha", 0, 1)
            pm.Normal("y", alpha, error)
        return simple_model

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_copy_model(self, copy_method) -> None:
        """A cloned model must yield the same prior-predictive draw as the original."""
        simple_model = self.simple_model()
        copy_simple_model = copy_method(simple_model)

        with simple_model:
            simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)

        with copy_simple_model:
            copy_simple_model_prior_predictive = pm.sample_prior_predictive(
                samples=1, random_seed=42
            )

        # Single draw with identical seeds on identical (cloned) graphs:
        # compare the values directly, no need to average over chain/draw.
        assert np.isclose(
            simple_model_prior_predictive["prior"]["y"].values,
            copy_simple_model_prior_predictive["prior"]["y"].values,
        )

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_guassian_process_copy_failure(self, copy_method) -> None:
        """Copying a model that contains GP-created variables must warn."""
        with pm.Model() as gaussian_process_model:
            ell = pm.Gamma("ell", alpha=2, beta=1)
            cov = 2 * pm.gp.cov.ExpQuad(1, ell)
            gp = pm.gp.Latent(cov_func=cov)
            f = gp.prior("f", X=np.arange(10)[:, None])
            pm.Normal("y", f * 2)

        # ``match`` is a regex searched within the message; use a stable
        # prefix instead of the full text (which contains regex
        # metacharacters such as '.' in the issue URL).
        with pytest.warns(
            UserWarning,
            match="Detected variables likely created by GP objects",
        ):
            copy_method(gaussian_process_model)

    @pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
    def test_adding_deterministics_to_clone(self, copy_method) -> None:
        """Variables added to a clone must not leak back into the source model."""
        simple_model = self.simple_model()
        clone_model = copy_method(simple_model)

        with clone_model:
            pm.Deterministic("z", clone_model["alpha"] + 1)

        assert "z" in clone_model.named_vars
        assert "z" not in simple_model.named_vars
Loading