|
23 | 23 | from typing import Any, Dict, Optional, Union |
24 | 24 |
|
25 | 25 | import numpy as np |
| 26 | +import pandas as pd |
26 | 27 | import pymc as pm |
27 | 28 | import pytensor.tensor as pt |
28 | 29 | from pymc_extras.prior import Prior |
@@ -65,9 +66,10 @@ class SpikeAndSlabPrior: |
65 | 66 | Creates a mixture prior with a point mass at zero (spike) and a diffuse |
66 | 67 | normal distribution (slab), implemented as: |
67 | 68 |
|
68 | | - β_j = γ_j × β_j^raw |
69 | | -
|
70 | | - where γ_j ∈ [0,1] is a relaxed indicator and β_j^raw ~ N(0, σ_slab²). |
| 69 | + .. math:: |
| 70 | + \beta_{j} = \gamma_{j} \cdot \beta_{j}^{\text{raw}} \\ |
| 71 | + \beta_{j}^{\text{raw}} \sim \mathcal{N}(0, \sigma_{\text{slab}}^{2}), \qquad |
| 72 | + \gamma_{j} \in [0,1]. |
71 | 73 |
|
72 | 74 | Parameters |
73 | 75 | ---------- |
@@ -145,9 +147,9 @@ class HorseshoePrior: |
145 | 147 | Provides continuous shrinkage with heavy tails, allowing strong signals |
146 | 148 | to escape shrinkage while weak signals are dampened: |
147 | 149 |
|
148 | | - β_j = τ · λ̃_j · β_j^raw |
149 | | -
|
150 | | - where λ̃_j = √(c²λ_j² / (c² + τ²λ_j²)) is the regularized local shrinkage. |
| 150 | + .. math:: |
| 151 | + \beta_{j} & = \tau \cdot \lambda_{j} \cdot \beta_{j}^{raw} \\ |
| 152 | + \lambda_{j} & = \sqrt{ \dfrac{c^{2}\lambda_{j}^{2}}{c^{2} + \tau^{2}\lambda_{j}^{2}} } |
151 | 153 |
|
152 | 154 | Parameters |
153 | 155 | ---------- |
@@ -423,7 +425,7 @@ def create_prior( |
423 | 425 |
|
424 | 426 | def get_inclusion_probabilities( |
425 | 427 | self, idata, param_name: str, threshold: float = 0.5 |
426 | | - ) -> Dict[str, np.ndarray]: |
| 428 | + ) -> pd.DataFrame: |
427 | 429 | """ |
428 | 430 | Extract variable inclusion probabilities from fitted model. |
429 | 431 |
|
@@ -472,17 +474,24 @@ def get_inclusion_probabilities( |
472 | 474 | gamma = az.extract(idata.posterior[gamma_name]) |
473 | 475 |
|
474 | 476 | # Compute inclusion probabilities |
475 | | - probabilities = (gamma > threshold).mean(dim="sample").values |
476 | | - gamma_mean = gamma.mean(dim="sample").values |
| 477 | + probabilities = (gamma > threshold).mean(dim="sample").to_array() |
| 478 | + gamma_mean = gamma.mean(dim="sample").to_array() |
477 | 479 | selected = probabilities > threshold |
478 | 480 |
|
479 | | - return { |
| 481 | + summary = { |
480 | 482 | "probabilities": probabilities, |
481 | 483 | "selected": selected, |
482 | 484 | "gamma_mean": gamma_mean, |
483 | 485 | } |
| 486 | + probs = summary["probabilities"].T |
| 487 | + df = pd.DataFrame(index=list(range(len(probs)))) |
| 488 | + |
| 489 | + df["prob"] = probs |
| 490 | + df["selected"] = summary["selected"].T |
| 491 | + df["gamma_mean"] = summary["gamma_mean"].T |
| 492 | + return df |
484 | 493 |
|
485 | | - def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray]: |
| 494 | + def get_shrinkage_factors(self, idata, param_name: str) -> pd.DataFrame: |
486 | 495 | """ |
487 | 496 | Extract shrinkage factors from horseshoe prior. |
488 | 497 |
|
@@ -524,17 +533,26 @@ def get_shrinkage_factors(self, idata, param_name: str) -> Dict[str, np.ndarray] |
524 | 533 | raise ValueError(f"Could not find '{lambda_tilde_name}' in posterior") |
525 | 534 |
|
526 | 535 | # Extract components |
527 | | - tau = az.extract(idata.posterior[tau_name]) |
528 | | - lambda_tilde = az.extract(idata.posterior[lambda_tilde_name]) |
| 536 | + tau = az.extract(idata.posterior[tau_name]).to_array() |
| 537 | + lambda_tilde = az.extract(idata.posterior[lambda_tilde_name]).to_array() |
529 | 538 |
|
530 | | - # Compute shrinkage factors |
531 | | - shrinkage_factors = (tau * lambda_tilde).mean(dim="sample").values |
| 539 | + shrinkage_factor = np.array( |
| 540 | + [tau[0, i] * lambda_tilde[0, :, :] for i in range(len(tau))] |
| 541 | + ) |
| 542 | + shrinkage_factor = shrinkage_factor.mean(axis=2) |
532 | 543 |
|
533 | | - return { |
534 | | - "shrinkage_factors": shrinkage_factors, |
535 | | - "tau": tau.mean().values, |
536 | | - "lambda_tilde": lambda_tilde.mean(dim="sample").values, |
| 544 | + summary = { |
| 545 | + "shrinkage_factors": shrinkage_factor, |
| 546 | + "tau": tau.mean(), |
| 547 | + "lambda_tilde": lambda_tilde.mean(dim=("sample")), |
537 | 548 | } |
| 549 | + probs = summary["shrinkage_factors"].T |
| 550 | + df = pd.DataFrame(index=list(range(len(probs)))) |
| 551 | + df["shrinkage_factor"] = probs |
| 552 | + |
| 553 | + df["lambda_tilde"] = summary["lambda_tilde"].T |
| 554 | + df["tau"] = np.mean(tau).item() |
| 555 | + return df |
538 | 556 |
|
539 | 557 |
|
540 | 558 | def create_variable_selection_prior( |
|
0 commit comments