Skip to content
This repository was archived by the owner on Feb 27, 2026. It is now read-only.

Commit 6193583

Browse files
committed
Fix evaluate method for GMM
1 parent ade1dfb commit 6193583

File tree

2 files changed

+10
-11
lines changed

2 files changed

+10
-11
lines changed

pycave/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
word2vec

pycave/bayes/gmm/engine.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,10 @@ def after_epoch(self, _):
4949
def after_training(self):
5050
self.cache = None
5151

52+
def evaluate(self, data, metrics=None, **kwargs):
53+
metrics = {'neg_log_likelihood': lambda x: x}
54+
return super().evaluate(data, metrics=metrics, **kwargs)
55+
5256
def train_batch(self, data, eps=0.01):
5357
# E-step: compute responsibilities
5458
responsibilities, nll = self.model(data)
@@ -87,14 +91,8 @@ def collate_losses(self, _):
8791
nll = self.cache['neg_log_likelihood']
8892
return {'neg_log_likelihood': nll}
8993

90-
def collate_predictions(self, predictions):
91-
sample = predictions[0]
92-
93-
if isinstance(sample, dict):
94-
# Only negative log-likelihood
95-
nll_sum = sum([p['nll'] for p in predictions])
96-
n = sum([p['n'] for p in predictions])
97-
return {'neg_log_likelihood': nll_sum / n}
98-
99-
# In this case, we return the responsibilities
100-
return torch.cat(predictions)
94+
def collate_evals(self, evals):
95+
# Only negative log-likelihood
96+
nll_sum = sum([p['nll'] for p in evals])
97+
n = sum([p['n'] for p in evals])
98+
return torch.as_tensor(nll_sum / n)

0 commit comments

Comments
 (0)