@@ -94,11 +94,14 @@ def mog_loss_model(n_components, d_t):
     # Use logsumexp for numeric stability:
     # LL = C - log(sum(exp(-d2/(2*sig^2) + log(pi_i/sig^d))))
     def make_logloss(d2, sig, pi):
+        # values = pi / K.pow(sig, d_t) * K.exp(-d2 / (2 * K.square(sig)))
+        # return -K.log(K.sum(values, axis=-1))
+
         # logsumexp doesn't exist in keras 2.4; simulate it
         values = -d2 / (2 * K.square(sig)) + K.log(pi / K.pow(sig, d_t))
         # logsumexp(a,b,c) = log(exp(a)+exp(b)+exp(c)) = log((exp(a-k)+exp(b-k)+exp(c-k))*exp(k))
         #                  = log((exp(a-k)+exp(b-k)+exp(c-k))) + k
-        mx = K.max(values, axis=-1)
+        mx = K.stop_gradient(K.max(values, axis=-1))
         return -K.log(K.sum(K.exp(values - L.Reshape((-1, 1))(mx)), axis=-1)) - mx

     ll = L.Lambda(lambda dsp: make_logloss(*dsp), output_shape=(1,))([d2, sig, pi])
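For context on this hunk: the commented-out lines are the direct density computation, which underflows to zero (and hence `log` to `-inf`) when the exponents are very negative; the live code uses the standard logsumexp shift instead. Since log(sum_i exp(v_i)) = log(sum_i exp(v_i - k)) + k for any constant k, choosing k = max(v) bounds every exponent by zero. The same identity also suggests why wrapping the max in `stop_gradient` is safe: the value (and true gradient) of the expression is unchanged for any constant k, so no gradient needs to flow through `mx` itself. A minimal NumPy sketch of the identity, purely illustrative and not part of the diff:

```python
import numpy as np

def naive_logsumexp(v):
    # Underflows: np.exp(-1000.0) == 0.0, so the log returns -inf
    # (with a divide-by-zero RuntimeWarning).
    return np.log(np.sum(np.exp(v), axis=-1))

def stable_logsumexp(v):
    # log(sum(exp(v))) = log(sum(exp(v - m))) + m for any constant m;
    # choosing m = max(v) bounds every exponent by 0.
    m = np.max(v, axis=-1, keepdims=True)
    return np.log(np.sum(np.exp(v - m), axis=-1)) + np.squeeze(m, axis=-1)

v = np.array([[-1000.0, -1001.0, -1002.0]])
print(naive_logsumexp(v))   # [-inf]
print(stable_logsumexp(v))  # [-999.592...]
```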
@@ -350,7 +353,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
 
         ll = mog_loss_model(n_components, d_t)([pi, mu, sig, t_in])
 
-        model = Model([z_in, x_in, t_in], [ll])
+        model = Model([z_in, x_in, t_in], [])
         model.add_loss(L.Lambda(K.mean)(ll))
         model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?
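On the `Model(..., [])` change: the training signal comes entirely from `add_loss`, so the model needs no predictive outputs, and keeping `ll` in the outputs list would make Keras expect a target and a loss for it at fit time. The same change is applied to the response model in the next hunk. A minimal sketch of the pattern, assuming `tensorflow.keras` and a Keras version that accepts an empty outputs list as this code does; all layer names and shapes here are illustrative:

```python
import numpy as np
from tensorflow.keras import backend as K
from tensorflow.keras import layers as L
from tensorflow.keras.models import Model

x_in = L.Input(shape=(3,))
# Stand-in for a per-sample loss tensor like `ll` above.
per_sample = L.Dense(1)(L.Dense(8, activation="relu")(x_in))

model = Model([x_in], [])                     # no predictive outputs
model.add_loss(L.Lambda(K.mean)(per_sample))  # attach the scalar training loss directly
model.compile("adam")                         # no loss= argument needed

# No targets are passed, since the loss is built from the inputs alone.
model.fit(np.random.rand(64, 3), epochs=1, verbose=0)
```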
@@ -365,7 +368,7 @@ def fit(self, Y, T, X, Z, *, inference=None):
                                  self._n_samples, self._use_upper_bound_loss, self._n_gradient_samples)
 
         rl = lm([z_in, x_in, y_in])
-        response_model = Model([z_in, x_in, y_in], [rl])
+        response_model = Model([z_in, x_in, y_in], [])
         response_model.add_loss(L.Lambda(K.mean)(rl))
         response_model.compile(self._optimizer)
         # TODO: do we need to give the user more control over other arguments to fit?