Skip to content

Commit 306d733

Browse files
authored
Merge pull request #1992 from cornellius-gp/pickle_constant_mean
Make it possible to pickle constant mean with prior
2 parents 1456e6f + 9a3af18 commit 306d733

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

gpytorch/means/constant_mean.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,13 @@ def __init__(self, prior=None, batch_shape=torch.Size(), **kwargs):
1212
self.batch_shape = batch_shape
1313
self.register_parameter(name="constant", parameter=torch.nn.Parameter(torch.zeros(*batch_shape, 1)))
1414
if prior is not None:
15-
self.register_prior("mean_prior", prior, "constant")
15+
self.register_prior("mean_prior", prior, self._constant_param, self._constant_closure)
16+
17+
def _constant_param(self, m):
18+
return m.constant
19+
20+
def _constant_closure(self, m, value):
21+
return m.constant.data.fill_(value)
1622

1723
def forward(self, input):
1824
if input.shape[:-2] == self.batch_shape:

test/means/test_constant_mean.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
11
#!/usr/bin/env python3
22

3+
import pickle
34
import unittest
45

56
import torch
67

78
from gpytorch.means import ConstantMean
9+
from gpytorch.priors import NormalPrior
810
from gpytorch.test.base_mean_test_case import BaseMeanTestCase
911

1012

1113
class TestConstantMean(BaseMeanTestCase, unittest.TestCase):
1214
def create_mean(self):
1315
return ConstantMean()
1416

17+
def test_prior(self):
18+
prior = NormalPrior(0.0, 1.0)
19+
mean = ConstantMean(prior=prior)
20+
self.assertEqual(mean.mean_prior, prior)
21+
pickle.loads(pickle.dumps(mean)) # Should be able to pickle and unpickle with a prior
22+
mean._constant_closure(mean, 1.234)
23+
self.assertAlmostEqual(mean.constant.item(), 1.234)
24+
1525

1626
class TestConstantMeanBatch(BaseMeanTestCase, unittest.TestCase):
1727
batch_shape = torch.Size([3])

0 commit comments

Comments
 (0)