Skip to content

Commit 015cf23

Browse files
solver feature (#29)
* solver feature * need scipy now
1 parent 82599f2 commit 015cf23

File tree

4 files changed

+90
-3
lines changed

4 files changed

+90
-3
lines changed

setup.cfg

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@ package_dir =
3030
# DON'T CHANGE THE FOLLOWING LINE! IT WILL BE UPDATED BY PYSCAFFOLD!
3131
setup_requires = pyscaffold>=3.2a0,<3.3a0
3232
# Add here dependencies of your project (semicolon/line-separated), e.g.
33-
install_requires = numpy
34-
matplotlib
35-
pyerf
33+
install_requires =
34+
numpy
35+
scipy
36+
matplotlib
37+
pyerf
3638

3739
# The usage of test_requires is discouraged, see `Dependency Management` docs
3840
# tests_require = pytest; pytest-cov

src/mud/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,5 @@
99
__version__ = 'unknown'
1010
finally:
1111
del get_distribution, DistributionNotFound
12+
13+
from .base import DensityProblem # noqa: F401

src/mud/base.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numpy as np
2+
from scipy.stats import distributions as dist
3+
from scipy.stats import gaussian_kde as gkde
4+
5+
6+
class DensityProblem(object):
7+
def __init__(self, X, y, domain=None):
8+
self.X = X
9+
self.y = y
10+
self.domain = np.array(domain)
11+
self._up = None
12+
self._pr = None
13+
self._in = None
14+
self._ob = None
15+
16+
def set_observed(self, distribution=dist.norm()):
17+
self._ob = distribution.pdf(self.y).prod(axis=1)
18+
19+
def set_initial(self, distribution=None):
20+
if distribution is None: # assume standard normal by default
21+
if self.domain is not None: # assume uniform if domain specified
22+
mn = np.min(self.domain, axis=1)
23+
mx = np.max(self.domain, axis=1)
24+
distribution = dist.uniform(loc=mn, scale=mx - mn)
25+
distribution = dist.norm()
26+
initial_dist = distribution
27+
self._in = initial_dist.pdf(self.X).prod(axis=1)
28+
29+
def set_predicted(self, distribution=None):
30+
if distribution is None:
31+
distribution = gkde(self.y.T)
32+
pred_pdf = distribution.pdf(self.y.T).T
33+
else:
34+
pred_pdf = distribution.pdf(self.y)
35+
self._pr = pred_pdf
36+
37+
def fit(self):
38+
if not self._in:
39+
self.set_initial()
40+
self._pr = None
41+
if not self._pr:
42+
self.set_predicted()
43+
if not self._ob:
44+
self.set_observed()
45+
46+
up_pdf = np.divide(np.multiply(self._in, self._ob), self._pr)
47+
self._up = up_pdf
48+
49+
def mud_point(self):
50+
if self._up is None:
51+
self.fit()
52+
m = np.argmax(self._up)
53+
return self.X[m, :]

src/mud/funs.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import numpy as np
1111

1212
from mud import __version__
13+
from mud import DensityProblem
1314

1415
__author__ = "Mathematical Michael"
1516
__copyright__ = "Mathematical Michael"
@@ -264,5 +265,34 @@ def iterate(A, b, y, initial_mean, initial_cov,
264265
return chain
265266

266267

268+
def mud_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None):
269+
"""
270+
Wrapper around mud problem, takes in raw qoi + synthetic data and
271+
performs WME transformation, instantiates solver object
272+
"""
273+
if lam.ndim == 1:
274+
lam = lam.reshape(-1, 1)
275+
276+
if qoi.ndim == 1:
277+
qoi = qoi.reshape(-1, 1)
278+
dim_output = qoi.shape[1]
279+
280+
if num_obs is None:
281+
num_obs = dim_output
282+
elif num_obs < 1:
283+
raise ValueError("num_obs must be >= 1")
284+
elif num_obs > dim_output:
285+
raise ValueError("num_obs must be <= dim(qoi)")
286+
287+
# this is our data processing step.
288+
data = qoi_true[0:num_obs] + np.random.randn(num_obs) * sd
289+
q = wme(qoi[:, 0:num_obs], data, sd).reshape(-1, 1)
290+
291+
# this just implements density-based solutions + mud point method
292+
d = DensityProblem(lam, q, domain)
293+
# d.fit() # optional. will compute if invoked while empty.
294+
return d
295+
296+
267297
if __name__ == "__main__":
268298
run()

0 commit comments

Comments
 (0)