Commit 900d8b9

implemented simplex constraint
1 parent 6e7b75b commit 900d8b9

File tree

6 files changed: +586, -11 lines


docs/theory.md

Lines changed: 28 additions & 1 deletion
@@ -53,4 +53,31 @@ The standard Gibbs sampler in `pybmc` assumes a Gaussian likelihood and conjugat

### Gibbs Sampler with Simplex Constraints

`pybmc` also provides a Gibbs sampler that enforces simplex constraints on the model weights (i.e., \(\sum w_k = 1\) and \(w_k \ge 0\)). This is achieved by performing a random walk in the space of the transformed parameters and using a Metropolis-Hastings step to accept or reject proposals that fall outside the valid simplex region.

#### When to Use Each Mode

| Mode | Description | Use When |
|------|-------------|----------|
| **Unconstrained** (default) | Weights can take any real value | Maximum flexibility; some models may get negative weights to cancel out biases |
| **Simplex** | Weights satisfy \(w_k \ge 0\) and \(\sum w_k = 1\) | You need interpretable weights that form a proper mixture; predictions should stay within the range of individual models |

#### Simplex Constraint Implementation

The simplex constraint is enforced through a Metropolis-within-Gibbs algorithm. In the SVD-reduced coefficient space, the relationship between the regression coefficients \(\boldsymbol{\beta}\) and the model weights \(\boldsymbol{\omega}\) is:

\[
\omega_k = \sum_{j=1}^m \beta_j \hat{V}_{jk} + \frac{1}{K}
\]

where \(\hat{V}\) contains the (normalized) right singular vectors and \(K\) is the number of models. The term \(\frac{1}{K}\) represents the equal-weight baseline.

At each iteration, the algorithm:

1. **Proposes** a new coefficient vector \(\boldsymbol{\beta}^*\) from a multivariate normal centered on the current value.
2. **Projects** the proposal to weight space via \(\boldsymbol{\omega}^* = \boldsymbol{\beta}^* \hat{V} + \frac{1}{K}\).
3. **Rejects** the proposal if any \(\omega_k^* < 0\) (the sum-to-one constraint is automatically satisfied by the SVD structure and the \(\frac{1}{K}\) offset).
4. **Accepts** valid proposals with probability \(\min\!\bigl(1,\; \exp\!\bigl[\bigl(\ell(\boldsymbol{\beta}^*) - \ell(\boldsymbol{\beta})\bigr) / \sigma^2\bigr]\bigr)\), where \(\ell\) is the log-likelihood.
5. **Samples** the error variance \(\sigma^2\) from its inverse-gamma full conditional.

The `burn` parameter controls the number of burn-in iterations discarded before collecting samples, and the `stepsize` parameter scales the proposal covariance matrix to tune the acceptance rate.
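The projection and accept/reject logic described above can be sketched in plain NumPy. This is an illustrative standalone implementation, not the `pybmc` source: the function name `simplex_mh_step`, the Gaussian log-likelihood \(\ell(\boldsymbol{\beta}) = -\tfrac{1}{2}\lVert y - \hat{U}\boldsymbol{\beta}\rVert^2\), and the array shapes are assumptions made for the sketch.

```python
import numpy as np

def simplex_mh_step(beta, y, U_hat, Vt_hat, sigma2, stepsize, rng):
    """One Metropolis step with a hard non-negativity constraint on the weights.

    beta   : (m,) current SVD-space coefficients
    y      : (n,) centered experimental values
    U_hat  : (n, m) reduced left singular vectors (design matrix)
    Vt_hat : (m, K) map from coefficients to weight deviations; its rows
             sum to zero, so the weights sum to one by construction.
    """
    K = Vt_hat.shape[1]

    # 1. Propose from a normal centered on the current value.
    beta_star = beta + stepsize * rng.standard_normal(beta.shape)

    # 2. Project the proposal to weight space: omega = beta @ Vt_hat + 1/K.
    omega_star = beta_star @ Vt_hat + 1.0 / K

    # 3. Reject outright if any weight is negative.
    if np.any(omega_star < 0):
        return beta

    # 4. Accept with probability min(1, exp[(l(b*) - l(b)) / sigma^2]),
    #    where l(b) = -0.5 * ||y - U_hat @ b||^2 (assumed Gaussian form).
    def log_lik(b):
        r = y - U_hat @ b
        return -0.5 * r @ r

    log_ratio = (log_lik(beta_star) - log_lik(beta)) / sigma2
    if np.log(rng.uniform()) < log_ratio:
        return beta_star
    return beta

# Demo: with a zero-row-sum Vt_hat, every visited state is a valid simplex point.
rng = np.random.default_rng(0)
Vt_hat = np.array([[0.5, -0.25, -0.25],
                   [0.0, 0.50, -0.50]])
U_hat = rng.standard_normal((5, 2))
y = rng.standard_normal(5)
beta = np.zeros(2)  # start at the equal-weight baseline (1/3, 1/3, 1/3)
for _ in range(200):
    beta = simplex_mh_step(beta, y, U_hat, Vt_hat, sigma2=0.5, stepsize=0.05, rng=rng)
omega = beta @ Vt_hat + 1.0 / 3
print(omega.min() >= 0, np.isclose(omega.sum(), 1.0))  # True True
```

Step 5 (the inverse-gamma draw for \(\sigma^2\)) is omitted here to keep the sketch focused on how the constraint itself is enforced.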

docs/usage.md

Lines changed: 77 additions & 0 deletions
@@ -112,6 +112,83 @@ With the data prepared and the model orthogonalized, we can train the model comb

```python
bmc.train(training_options={"iterations": 50000, "sampler": "gibbs_sampling"})
```

### Simplex Constraint Mode

By default, `pybmc` uses an unconstrained Gibbs sampler where model weights can take any real value. If you want to enforce that the weights lie on the **probability simplex** — meaning each weight is between 0 and 1 and the weights sum to 1 — you can enable the simplex constraint mode.

!!! tip "When to Use Simplex Constraints"
    Use simplex constraints when you want the model combination to behave as a
    **proper weighted average** of the constituent models. This is appropriate when:

    - You want each model to contribute non-negatively to the prediction.
    - The combined prediction should remain within the range spanned by the individual models.
    - Physical interpretability of the weights matters for your application.

    The unconstrained mode is more flexible and may yield better predictive performance
    when some models systematically over- or under-predict, since negative weights can
    partially cancel out biased models.

There are two ways to enable simplex constraints:

**Option 1: Set at initialization (recommended when you always want simplex)**

```python
bmc = BayesianModelCombination(
    models_list=["FRDM12", "HFB24", "D1M", "UNEDF1", "BCPM"],
    data_dict=data_dict,
    truth_column_name="AME2020",
    constraint="simplex",  # <-- weights constrained to [0, 1], sum to 1
)

bmc.orthogonalize("BE", train_df, components_kept=3)
bmc.train(training_options={
    "iterations": 50000,
    "burn": 10000,      # burn-in iterations for the Metropolis step
    "stepsize": 0.001,  # proposal step size
})
```

**Option 2: Override per training call**

```python
# Initialize with the default unconstrained mode
bmc = BayesianModelCombination(
    models_list=["FRDM12", "HFB24", "D1M", "UNEDF1", "BCPM"],
    data_dict=data_dict,
    truth_column_name="AME2020",
)

bmc.orthogonalize("BE", train_df, components_kept=3)

# Override to simplex for this specific training run
bmc.train(training_options={
    "iterations": 50000,
    "sampler": "simplex",
    "burn": 10000,
    "stepsize": 0.001,
})
```

### Inspecting Model Weights

After training, you can inspect the inferred model weights using `get_weights()`:

```python
# Get a summary (mean, std, median per model)
summary = bmc.get_weights()
for model, mean_w, std_w in zip(summary["models"], summary["mean"], summary["std"]):
    print(f"  {model}: {mean_w:.4f} ± {std_w:.4f}")

# Get the full weight matrix (n_samples × n_models) for custom analysis
weight_matrix = bmc.get_weights(summary=False)
```

In simplex mode, every row of the weight matrix is guaranteed to satisfy \(w_k \ge 0\) and \(\sum_k w_k = 1\).

## 4. Make Predictions

After training, we can use the `predict` method to generate predictions with uncertainty quantification. The method returns the full posterior draws, as well as DataFrames for the lower, median, and upper credible intervals.
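The simplex guarantee on the weight matrix is easy to check numerically. The sketch below uses a synthetic matrix as a stand-in for the output of `bmc.get_weights(summary=False)`; the helper `check_simplex` is not part of `pybmc`.

```python
import numpy as np

def check_simplex(weight_matrix, tol=1e-8):
    """True if every row is non-negative and sums to one (within tol)."""
    nonneg = np.all(weight_matrix >= -tol)
    sums_to_one = np.allclose(weight_matrix.sum(axis=1), 1.0, atol=tol)
    return bool(nonneg and sums_to_one)

# Synthetic (n_samples x n_models) stand-in for bmc.get_weights(summary=False):
rng = np.random.default_rng(42)
raw = rng.random((1000, 5))
weight_matrix = raw / raw.sum(axis=1, keepdims=True)  # normalize each row

print(check_simplex(weight_matrix))            # True
print(check_simplex(np.array([[1.5, -0.5]])))  # False: negative entry
```

A check like this is useful as a quick regression test after training in simplex mode; in unconstrained mode it will typically return `False`.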

pybmc/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -10,7 +10,7 @@
 
 from .data import Dataset
 from .bmc import BayesianModelCombination
-from .inference_utils import gibbs_sampler, USVt_hat_extraction
+from .inference_utils import gibbs_sampler, gibbs_sampler_simplex, USVt_hat_extraction
 from .sampling_utils import coverage
 
 
@@ -19,6 +19,7 @@
     "Dataset",
     "BayesianModelCombination",
     "gibbs_sampler",
+    "gibbs_sampler_simplex",
     "USVt_hat_extraction",
     "coverage",
 ]

pybmc/bmc.py

Lines changed: 87 additions & 8 deletions
@@ -3,7 +3,7 @@
 import matplotlib.pyplot as plt
 from sklearn.model_selection import train_test_split
 import os
-from .inference_utils import gibbs_sampler, USVt_hat_extraction
+from .inference_utils import gibbs_sampler, gibbs_sampler_simplex, USVt_hat_extraction
 from .sampling_utils import coverage, rndm_m_random_calculator
 
 
@@ -16,26 +16,41 @@ class BayesianModelCombination:
     + Predictions for certain isotopes.
     """
 
-    def __init__(self, models_list, data_dict, truth_column_name, weights=None):
+    VALID_CONSTRAINTS = ("unconstrained", "simplex")
+
+    def __init__(self, models_list, data_dict, truth_column_name, weights=None, constraint="unconstrained"):
         """
         Initialize the BayesianModelCombination class.
 
         :param models_list: List of model names
         :param data_dict: Dictionary from `load_data()` where each key is a model name and each value is a DataFrame of properties
         :param truth_column_name: Name of the column containing the truth values.
         :param weights: Optional initial weights for the models.
+        :param constraint: Weight constraint mode. Options:
+            - ``"unconstrained"`` (default): No constraints on model weights.
+            - ``"simplex"``: Forces weights to lie on the probability simplex
+              (each weight between 0 and 1, weights sum to 1). Uses a
+              Metropolis-within-Gibbs sampler to enforce the constraint.
         """
 
         if not isinstance(models_list, list) or not all(isinstance(model, str) for model in models_list):
             raise ValueError("The 'models' should be a list of model names (strings) for Bayesian Combination.")
         if not isinstance(data_dict, dict) or not all(isinstance(df, pd.DataFrame) for df in data_dict.values()):
             raise ValueError("The 'data_dict' should be a dictionary of pandas DataFrames, one per property.")
+        if constraint not in self.VALID_CONSTRAINTS:
+            raise ValueError(
+                f"Invalid constraint '{constraint}'. "
+                f"Must be one of {self.VALID_CONSTRAINTS}."
+            )
 
         self.data_dict = data_dict
         self.models_list = models_list
         self.models = [m for m in models_list if m != 'truth']
         self.weights = weights if weights is not None else None
         self.truth_column_name = truth_column_name
+        self.constraint = constraint
+        self.samples = None
+        self.Vt_hat = None
 
 
     def orthogonalize(self, property, train_df, components_kept):

@@ -85,27 +100,61 @@ def train(self, training_options=None):
         """
         Train the model combination using training data and optional training parameters.
 
-        :param training_data: Placeholder (not used).
         :param training_options: Dictionary of training options. Keys:
             - 'iterations': (int) Number of Gibbs iterations (default 50000)
+            - 'sampler': (str) Override the constraint mode for this training run.
+              ``"unconstrained"`` or ``"simplex"``. If not provided, uses the
+              instance-level ``self.constraint`` set at initialization.
             - 'b_mean_prior': (np.ndarray) Prior mean vector (default zeros)
+              *(unconstrained sampler only)*
             - 'b_mean_cov': (np.ndarray) Prior covariance matrix (default diag(S_hat²))
+              *(unconstrained sampler only)*
             - 'nu0_chosen': (float) Degrees of freedom for variance prior (default 1.0)
             - 'sigma20_chosen': (float) Prior variance (default 0.02)
+            - 'burn': (int) Burn-in iterations (default 10000)
+              *(simplex sampler only)*
+            - 'stepsize': (float) Proposal step size (default 0.001)
+              *(simplex sampler only)*
         """
         if training_options is None:
             training_options = {}
 
+        # Determine which sampler to use: training_options overrides instance default
+        sampler_mode = training_options.get('sampler', self.constraint)
+        if sampler_mode not in self.VALID_CONSTRAINTS:
+            raise ValueError(
+                f"Invalid sampler '{sampler_mode}'. "
+                f"Must be one of {self.VALID_CONSTRAINTS}."
+            )
+
         iterations = training_options.get('iterations', 50000)
         num_components = self.U_hat.shape[1]
         S_hat = self.S_hat
-
-        b_mean_prior = training_options.get('b_mean_prior', np.zeros(num_components))
-        b_mean_cov = training_options.get('b_mean_cov', np.diag(S_hat**2))
         nu0_chosen = training_options.get('nu0_chosen', 1.0)
         sigma20_chosen = training_options.get('sigma20_chosen', 0.02)
 
-        self.samples = gibbs_sampler(self.centered_experiment_train, self.U_hat, iterations, [b_mean_prior, b_mean_cov, nu0_chosen, sigma20_chosen])
+        if sampler_mode == "simplex":
+            burn = training_options.get('burn', 10000)
+            stepsize = training_options.get('stepsize', 0.001)
+            self.samples = gibbs_sampler_simplex(
+                self.centered_experiment_train,
+                self.U_hat,
+                self.Vt_hat,
+                self.S_hat,
+                iterations,
+                [nu0_chosen, sigma20_chosen],
+                burn=burn,
+                stepsize=stepsize,
+            )
+        else:
+            b_mean_prior = training_options.get('b_mean_prior', np.zeros(num_components))
+            b_mean_cov = training_options.get('b_mean_cov', np.diag(S_hat**2))
+            self.samples = gibbs_sampler(
+                self.centered_experiment_train,
+                self.U_hat,
+                iterations,
+                [b_mean_prior, b_mean_cov, nu0_chosen, sigma20_chosen],
+            )

@@ -185,7 +234,37 @@ def evaluate(self, domain_filter=None):
 
         return coverage(np.arange(0, 101, 5), rndm_m, df, truth_column=self.truth_column_name)
 
+    def get_weights(self, summary=True):
+        """
+        Compute model weights from posterior samples.
+
+        Converts the sampled coefficient vectors (beta) into model weights
+        using the transformation ``omega = beta @ Vt_hat + 1/M``, where M is
+        the number of models. In simplex-constrained mode, all weights are
+        guaranteed to be non-negative and sum to 1.
+
+        :param summary: If True (default), return a dictionary with
+            ``'mean'``, ``'std'``, ``'median'`` arrays keyed by statistic.
+            If False, return the full ``(n_samples, n_models)`` weight matrix.
+        :return: Weight summary dict or full weight matrix.
+        :raises ValueError: If ``train()`` has not been called.
+        """
+        if self.samples is None or self.Vt_hat is None:
+            raise ValueError("Must call `orthogonalize()` and `train()` before getting weights.")
+
+        betas = self.samples[:, :-1]
+        n_models = self.Vt_hat.shape[1]
+        default_weights = np.full(n_models, 1.0 / n_models)
+        weight_matrix = betas @ self.Vt_hat + default_weights
+
+        if summary:
+            return {
+                "mean": np.mean(weight_matrix, axis=0),
+                "std": np.std(weight_matrix, axis=0),
+                "median": np.median(weight_matrix, axis=0),
+                "models": self.models,
+            }
+        return weight_matrix
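The `omega = beta @ Vt_hat + 1/M` transformation at the heart of `get_weights` can be exercised standalone on synthetic posterior samples; the shapes and random values below are illustrative, not taken from a real run.

```python
import numpy as np

rng = np.random.default_rng(7)
n_samples, m, n_models = 500, 2, 4

# Synthetic posterior: m beta columns plus a final variance column,
# mirroring the layout get_weights assumes (betas = samples[:, :-1]).
samples = rng.standard_normal((n_samples, m + 1))
Vt_hat = rng.standard_normal((m, n_models))

betas = samples[:, :-1]                         # drop the variance column
weight_matrix = betas @ Vt_hat + 1.0 / n_models # one weight row per draw

summary = {
    "mean": weight_matrix.mean(axis=0),
    "std": weight_matrix.std(axis=0),
    "median": np.median(weight_matrix, axis=0),
}
print(weight_matrix.shape, summary["mean"].shape)  # (500, 4) (4,)
```

With a random `Vt_hat` these weights are unconstrained; the simplex guarantees only hold when the samples come from the constrained sampler.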
pybmc/sampling_utils.py

Lines changed: 3 additions & 1 deletion
@@ -54,7 +54,9 @@ def rndm_m_random_calculator(filtered_model_predictions, samples, Vt_hat):
     np.random.seed(142858)
     rng = np.random.default_rng()
 
-    theta_rand_selected = rng.choice(samples, 10000, replace=False)
+    n_draws = min(10000, len(samples))
+    replace = len(samples) < 10000
+    theta_rand_selected = rng.choice(samples, n_draws, replace=replace)
 
     # Extract betas and noise std deviations
     betas = theta_rand_selected[:, :-1]  # shape: (10000, num_models - 1)
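The guarded draw above avoids the error `Generator.choice` raises when more unique rows are requested than exist. The pattern in isolation (a minimal sketch, not the `pybmc` function; `subsample_rows` is a hypothetical name):

```python
import numpy as np

def subsample_rows(samples, target=10000, seed=None):
    """Draw up to `target` rows; sample with replacement only when the
    chain is shorter than the target (mirrors the guarded draw above)."""
    rng = np.random.default_rng(seed)
    n_draws = min(target, len(samples))
    replace = len(samples) < target
    # Generator.choice selects whole rows of a 2-D array along axis 0.
    return rng.choice(samples, n_draws, replace=replace)

short_chain = np.arange(12).reshape(6, 2)  # only 6 posterior draws available
picked = subsample_rows(short_chain, target=10000, seed=0)
print(picked.shape)  # (6, 2): no ValueError, unlike replace=False with size 10000
```

Without the guard, `rng.choice(samples, 10000, replace=False)` fails on any chain shorter than 10000 draws, which is exactly the case after aggressive burn-in.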
