|
10 | 10 | import numpy as np |
11 | 11 |
|
12 | 12 | from mud import __version__ |
| 13 | +from mud import DensityProblem |
13 | 14 |
|
14 | 15 | __author__ = "Mathematical Michael" |
15 | 16 | __copyright__ = "Mathematical Michael" |
@@ -264,5 +265,34 @@ def iterate(A, b, y, initial_mean, initial_cov, |
264 | 265 | return chain |
265 | 266 |
|
266 | 267 |
|
| 268 | +def mud_problem(lam, qoi, qoi_true, domain, sd=0.05, num_obs=None): |
| 269 | + """ |
| 270 | + Wrapper around mud problem, takes in raw qoi + synthetic data and |
| 271 | + performs WME transformation, instantiates solver object |
| 272 | + """ |
| 273 | + if lam.ndim == 1: |
| 274 | + lam = lam.reshape(-1, 1) |
| 275 | + |
| 276 | + if qoi.ndim == 1: |
| 277 | + qoi = qoi.reshape(-1, 1) |
| 278 | + dim_output = qoi.shape[1] |
| 279 | + |
| 280 | + if num_obs is None: |
| 281 | + num_obs = dim_output |
| 282 | + elif num_obs < 1: |
| 283 | + raise ValueError("num_obs must be >= 1") |
| 284 | + elif num_obs > dim_output: |
| 285 | + raise ValueError("num_obs must be <= dim(qoi)") |
| 286 | + |
| 287 | + # this is our data processing step. |
| 288 | + data = qoi_true[0:num_obs] + np.random.randn(num_obs) * sd |
| 289 | + q = wme(qoi[:, 0:num_obs], data, sd).reshape(-1, 1) |
| 290 | + |
| 291 | + # this just implements density-based solutions + mud point method |
| 292 | + d = DensityProblem(lam, q, domain) |
| 293 | +# d.fit() # optional. will compute if invoked while empty. |
| 294 | + return d |
| 295 | + |
| 296 | + |
267 | 297 | if __name__ == "__main__": |
268 | 298 | run() |
0 commit comments