-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow copy and deepcopy of PYMC models #7492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 5 commits
b34742d
fe4e0c5
bcb4309
33c5766
88fde25
90419cb
07106ec
fb00f85
d057a9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -1571,6 +1571,62 @@ | |||||||||||
def __contains__(self, key): | ||||||||||||
return key in self.named_vars or self.name_for(key) in self.named_vars | ||||||||||||
|
||||||||||||
def __copy__(self): | ||||||||||||
ricardoV94 marked this conversation as resolved.
Show resolved
Hide resolved
|
||||||||||||
""" | ||||||||||||
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | ||||||||||||
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | ||||||||||||
Examples | ||||||||||||
-------- | ||||||||||||
.. code-block:: python | ||||||||||||
import pymc as pm | ||||||||||||
import copy | ||||||||||||
with pm.Model() as m: | ||||||||||||
p = pm.Beta("p", 1, 1) | ||||||||||||
x = pm.Bernoulli("x", p=p, shape=(3,)) | ||||||||||||
clone_m = copy.copy(m) | ||||||||||||
# Access cloned variables by name | ||||||||||||
clone_x = clone_m["x"] | ||||||||||||
# z will be part of clone_m but not m | ||||||||||||
z = pm.Deterministic("z", clone_x + 1) | ||||||||||||
""" | ||||||||||||
from pymc.model.fgraph import clone_model | ||||||||||||
|
||||||||||||
return clone_model(self) | ||||||||||||
|
||||||||||||
def __deepcopy__(self, _): | ||||||||||||
""" | ||||||||||||
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | ||||||||||||
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | ||||||||||||
|
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph. | |
Constants are not cloned and if guassian process variables are detected then a warning will be triggered. | |
Clone the model. | |
To access variables in the cloned model use `cloned_model["var_name"]`. |
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -11,6 +11,7 @@ | |||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||||
# See the License for the specific language governing permissions and | ||||||
# limitations under the License. | ||||||
import copy | ||||||
import pickle | ||||||
import threading | ||||||
import traceback | ||||||
|
@@ -1761,3 +1762,62 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: | |||||
figsize=None, | ||||||
dpi=300, | ||||||
) | ||||||
|
||||||
|
||||||
class TestModelCopy: | ||||||
@staticmethod | ||||||
def simple_model() -> pm.Model: | ||||||
|
||||||
with pm.Model() as simple_model: | ||||||
error = pm.HalfNormal("error", 0.5) | ||||||
alpha = pm.Normal("alpha", 0, 1) | ||||||
pm.Normal("y", alpha, error) | ||||||
return simple_model | ||||||
|
||||||
@staticmethod | ||||||
def gp_model() -> pm.Model: | ||||||
with pm.Model() as gp_model: | ||||||
ell = pm.Gamma("ell", alpha=2, beta=1) | ||||||
cov = 2 * pm.gp.cov.ExpQuad(1, ell) | ||||||
gp = pm.gp.Latent(cov_func=cov) | ||||||
f = gp.prior("f", X=np.arange(10)[:, None]) | ||||||
pm.Normal("y", f * 2) | ||||||
return gp_model | ||||||
|
||||||
def test_copy_model(self) -> None: | ||||||
|
||||||
simple_model = self.simple_model() | ||||||
copy_simple_model = copy.copy(simple_model) | ||||||
deepcopy_simple_model = copy.deepcopy(simple_model) | ||||||
|
||||||
with simple_model: | ||||||
simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
|
||||||
|
||||||
with copy_simple_model: | ||||||
copy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
|
||||||
with deepcopy_simple_model: | ||||||
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42) | ||||||
|
||||||
simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean( | ||||||
("chain", "draw") | ||||||
) | ||||||
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][ | ||||||
"y" | ||||||
].mean(("chain", "draw")) | ||||||
|
||||||
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[ | ||||||
"prior" | ||||||
]["y"].mean(("chain", "draw")) | ||||||
|
||||||
assert np.isclose( | ||||||
|
||||||
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean | ||||||
) | ||||||
assert np.isclose( | ||||||
simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean | ||||||
) | ||||||
|
||||||
def test_guassian_process_copy_failure(self) -> None: | ||||||
gaussian_process_model = self.gp_model() | ||||||
with pytest.warns(UserWarning): | ||||||
|
with pytest.warns(UserWarning): | |
with pytest.warns(UserWarning match=...): |
Uh oh!
There was an error while loading. Please reload this page.