
Commit 4e97471

Allow GMM evaluation to return NLL for individual datapoints
1 parent: 6193583

3 files changed: +22 −18 lines changed

pycave/bayes/_internal/utils.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -80,7 +80,7 @@ def log_responsibilities(log_probs, comp_priors, return_log_likelihood=False):
     log_resp = posterior - evidence
 
     if return_log_likelihood:
-        return log_resp, evidence.sum()
+        return log_resp, evidence
     return log_resp
 
 
```
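The diff only shows the tail of `log_responsibilities`. As a rough sketch under stated assumptions (the construction of `posterior` and `evidence` is inferred from the variable names, not taken from the source), the helper plausibly looks like the following, which shows why returning `evidence` unreduced yields one log-likelihood per datapoint:

```python
import torch

def log_responsibilities(log_probs, comp_priors, return_log_likelihood=False):
    """Illustrative sketch only: assumed shapes are log_probs [N, K] and comp_priors [K]."""
    posterior = log_probs + comp_priors.log()                    # unnormalized log joint, [N, K]
    evidence = torch.logsumexp(posterior, dim=-1, keepdim=True)  # per-datapoint log-likelihood, [N, 1]
    log_resp = posterior - evidence                              # log responsibilities, [N, K]

    if return_log_likelihood:
        # After this commit: return the per-datapoint evidence instead of evidence.sum(),
        # leaving the reduction (mean/sum/none) to the caller.
        return log_resp, evidence
    return log_resp
```

This unreduced tensor is what lets the model's `forward` and `GMMEngine.evaluate` below expose one NLL value per datapoint.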

pycave/bayes/gmm/engine.py

Lines changed: 13 additions & 14 deletions

```diff
@@ -49,14 +49,22 @@ def after_epoch(self, _):
     def after_training(self):
         self.cache = None
 
-    def evaluate(self, data, metrics=None, **kwargs):
-        metrics = {'neg_log_likelihood': lambda x: x}
-        return super().evaluate(data, metrics=metrics, **kwargs)
+    def evaluate(self, data, reduction='mean', **kwargs):
+        def reduce(x):
+            if reduction == 'mean':
+                return x.mean()
+            if reduction == 'sum':
+                return x.sum()
+            return x
+
+        metrics = {'neg_log_likelihood': reduce}
+        result = super().evaluate(data, metrics=metrics, **kwargs)
+        return result['neg_log_likelihood']
 
     def train_batch(self, data, eps=0.01):
         # E-step: compute responsibilities
         responsibilities, nll = self.model(data)
-        nll_ = nll.item() / data.size(0)
+        nll_ = nll.mean().item()
 
         # M-step: maximize
         gaussian_max = self.model.gaussian.maximize(data, responsibilities, self.requires_batching)
@@ -78,10 +86,7 @@ def train_batch(self, data, eps=0.01):
         self.cache['eps'] = eps
 
     def eval_batch(self, data):
-        return {
-            'nll': self.model(data)[1].item(),  # nll
-            'n': data.data.size(0)
-        }
+        return self.model(data)[1]  # NLL for all data samples
 
     def predict_batch(self, data):
         # Get responsibilities and normalize them to get a distribution over components
@@ -90,9 +95,3 @@ def predict_batch(self, data):
     def collate_losses(self, _):
         nll = self.cache['neg_log_likelihood']
         return {'neg_log_likelihood': nll}
-
-    def collate_evals(self, evals):
-        # Only negative log-likelihood
-        nll_sum = sum([p['nll'] for p in evals])
-        n = sum([p['n'] for p in evals])
-        return torch.as_tensor(nll_sum / n)
```
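The effect of the new `reduction` argument can be illustrated without the rest of the engine; the snippet below simply mirrors the `reduce` closure from the diff, applied to a made-up per-sample NLL tensor (the values are arbitrary):

```python
import torch

# Made-up per-sample NLLs, as eval_batch now returns them (one value per datapoint)
per_sample_nll = torch.tensor([2.31, 1.87, 3.05, 2.64])

def reduce(x, reduction='mean'):
    # Mirrors the closure introduced in GMMEngine.evaluate
    if reduction == 'mean':
        return x.mean()
    if reduction == 'sum':
        return x.sum()
    return x  # reduction='none': keep one NLL per datapoint

print(reduce(per_sample_nll, 'mean'))   # tensor(2.4675)
print(reduce(per_sample_nll, 'none'))   # tensor([2.3100, 1.8700, 3.0500, 2.6400])
```

With `reduction='mean'`, the reported value matches the previous behaviour, which summed the NLL over datapoints and divided by their count in `collate_evals`.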

pycave/bayes/gmm/model.py

Lines changed: 8 additions & 3 deletions

```diff
@@ -38,6 +38,11 @@ class may be used to find clusters whenever you expect data to be generated from
     eps: float, default: 0.01
         The minimum per-datapoint difference in the negative log-likelihood to consider a
         model "better", thus indicating convergence.
+
+    `evaluate(...)`
+    reduction: str, default: 'mean'
+        The reduction performed for the negative log-likelihood as for common PyTorch metrics.
+        Must be one of ['mean', 'sum', 'none'].
     """
 
     __engine__ = GMMEngine
@@ -118,14 +123,14 @@ def forward(self, data):
         -------
         torch.Tensor [N, K]
             The responsibilities for each datapoint and component (number of components K).
-        torch.Tensor [1]
-            The negative log-likelihood of the data.
+        torch.Tensor [N]
+            The negative log-likelihood for all data samples.
         """
         probs = self.gaussian.evaluate(data, log=True)
        log_resp, log_likeli = log_responsibilities(
            probs, self.component_weights, return_log_likelihood=True
        )
-        return log_resp.exp(), -log_likeli
+        return log_resp.exp(), -log_likeli.squeeze(-1)
 
     def sample(self, n, return_components=False):
         """
```
