@@ -49,14 +49,22 @@ def after_epoch(self, _):
4949 def after_training (self ):
5050 self .cache = None
5151
52- def evaluate (self , data , metrics = None , ** kwargs ):
53- metrics = {'neg_log_likelihood' : lambda x : x }
54- return super ().evaluate (data , metrics = metrics , ** kwargs )
52+ def evaluate (self , data , reduction = 'mean' , ** kwargs ):
53+ def reduce (x ):
54+ if reduction == 'mean' :
55+ return x .mean ()
56+ if reduction == 'sum' :
57+ return x .sum ()
58+ return x
59+
60+ metrics = {'neg_log_likelihood' : reduce }
61+ result = super ().evaluate (data , metrics = metrics , ** kwargs )
62+ return result ['neg_log_likelihood' ]
5563
5664 def train_batch (self , data , eps = 0.01 ):
5765 # E-step: compute responsibilities
5866 responsibilities , nll = self .model (data )
59- nll_ = nll .item () / data . size ( 0 )
67+ nll_ = nll .mean (). item ( )
6068
6169 # M-step: maximize
6270 gaussian_max = self .model .gaussian .maximize (data , responsibilities , self .requires_batching )
@@ -78,10 +86,7 @@ def train_batch(self, data, eps=0.01):
7886 self .cache ['eps' ] = eps
7987
8088 def eval_batch (self , data ):
81- return {
82- 'nll' : self .model (data )[1 ].item (), # nll
83- 'n' : data .data .size (0 )
84- }
89+ return self .model (data )[1 ] # NLL for all data samples
8590
8691 def predict_batch (self , data ):
8792 # Get responsibilities and normalize them to get a distribution over components
@@ -90,9 +95,3 @@ def predict_batch(self, data):
9095 def collate_losses (self , _ ):
9196 nll = self .cache ['neg_log_likelihood' ]
9297 return {'neg_log_likelihood' : nll }
93-
94- def collate_evals (self , evals ):
95- # Only negative log-likelihood
96- nll_sum = sum ([p ['nll' ] for p in evals ])
97- n = sum ([p ['n' ] for p in evals ])
98- return torch .as_tensor (nll_sum / n )
0 commit comments