-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow copy and deepcopy of PYMC models #7492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 6 commits
b34742d
fe4e0c5
bcb4309
33c5766
88fde25
90419cb
07106ec
fb00f85
d057a9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1571,6 +1571,40 @@ def __getitem__(self, key): | |||||||||||
def __contains__(self, key): | ||||||||||||
return key in self.named_vars or self.name_for(key) in self.named_vars | ||||||||||||
|
||||||||||||
def __copy__(self): | ||||||||||||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
return self.copy() | ||||||||||||
|
||||||||||||
def __deepcopy__(self, _): | ||||||||||||
return self.copy() | ||||||||||||
|
||||||||||||
def copy(self): | ||||||||||||
""" | ||||||||||||
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | ||||||||||||
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | ||||||||||||
|
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | |
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | |
Clone the model. | |
To access variables in the cloned model use `cloned_model["var_name"]`. |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import copy | ||
import pickle | ||
import threading | ||
import traceback | ||
|
@@ -1761,3 +1762,63 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: | |
figsize=None, | ||
dpi=300, | ||
) | ||
|
||
|
||
class TestModelCopy: | ||
@staticmethod | ||
def simple_model() -> pm.Model: | ||
|
||
with pm.Model() as simple_model: | ||
error = pm.HalfNormal("error", 0.5) | ||
alpha = pm.Normal("alpha", 0, 1) | ||
pm.Normal("y", alpha, error) | ||
return simple_model | ||
|
||
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) | ||
def test_copy_model(self, copy_method) -> None: | ||
simple_model = self.simple_model() | ||
copy_simple_model = copy_method(simple_model) | ||
|
||
with simple_model: | ||
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42) | ||
|
||
with copy_simple_model: | ||
copy_simple_model_prior_predictive = pm.sample_prior_predictive( | ||
samples=1, random_seed=42 | ||
) | ||
|
||
simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean( | ||
("chain", "draw") | ||
) | ||
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][ | ||
"y" | ||
].mean(("chain", "draw")) | ||
|
||
|
||
assert np.isclose( | ||
|
||
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean | ||
) | ||
|
||
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) | ||
def test_guassian_process_copy_failure(self, copy_method) -> None: | ||
with pm.Model() as gaussian_process_model: | ||
ell = pm.Gamma("ell", alpha=2, beta=1) | ||
cov = 2 * pm.gp.cov.ExpQuad(1, ell) | ||
gp = pm.gp.Latent(cov_func=cov) | ||
f = gp.prior("f", X=np.arange(10)[:, None]) | ||
pm.Normal("y", f * 2) | ||
|
||
with pytest.warns( | ||
UserWarning, | ||
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883", | ||
): | ||
copy_method(gaussian_process_model) | ||
|
||
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) | ||
def test_adding_deterministics_to_clone(self, copy_method) -> None: | ||
|
||
simple_model = self.simple_model() | ||
clone_model = copy_method(simple_model) | ||
|
||
with clone_model: | ||
z = pm.Deterministic("z", clone_model["alpha"] + 1) | ||
|
||
assert "z" in clone_model.named_vars | ||
assert "z" not in simple_model.named_vars |
Uh oh!
There was an error while loading. Please reload this page.