@@ -32,7 +32,27 @@ def test_stop_grad(self):
3232 model = keras .Model ([x_input , y_input , z_input ], [loss ])
3333 model .add_loss (K .mean (loss ))
3434 model .compile ('nadam' )
35- model .fit ([np .array ([[1 ]]), np .array ([[2 ]]), np .array ([[0 ]])], [])
35+ model .fit ([np .array ([[1 ]]), np .array ([[2 ]]), np .array ([[0 ]])])
36+
37+ def test_mog_loss (self ):
38+ inputs = [keras .layers .Input (shape = s ) for s in [(3 ,), (3 , 2 ), (3 ,), (2 ,)]]
39+ ll_model = keras .engine .Model (inputs , mog_loss_model (3 , 2 )(inputs ))
40+
41+ for n in range (10 ):
42+ ps = - np .log (np .random .uniform (size = (3 ,)))
43+ pi = ps / np .sum (ps )
44+ mu = np .random .normal (size = (3 , 2 ))
45+ sig = np .exp (np .random .normal (size = 3 ,))
46+ t = np .random .normal (size = (2 ,))
47+
48+ pred = ll_model .predict ([pi .reshape (1 , 3 ), mu .reshape (1 , 3 , 2 ), sig .reshape (1 , 3 ), t .reshape (1 , 2 )])
49+
50+ # LL = C - log(sum(pi_i/sig^d * exp(-d2/(2*sig^2))))
51+ d = mu - t .reshape (- 1 , 2 )
52+ d2 = np .sum (d * d , axis = - 1 )
53+ ll = - np .log (np .sum (pi / (sig * sig ) * np .exp (- d2 / (2 * sig * sig )), axis = 0 ))
54+
55+ assert np .allclose (ll , pred [0 ])
3656
3757 @pytest .mark .slow
3858 def test_deepiv_shape (self ):
@@ -500,7 +520,7 @@ def norm(lr):
500520 model = keras .engine .Model ([x_input , t_input ], [ll ])
501521 model .add_loss (K .mean (ll ))
502522 model .compile ('nadam' )
503- model .fit ([x , t ], [], epochs = 5 )
523+ model .fit ([x , t ], epochs = 5 )
504524
505525 # For some reason this doesn't work at all when run against the CNTK backend...
506526 # model.compile('nadam', loss=lambda _,l:l)
@@ -559,7 +579,7 @@ def sample(n):
559579 model = keras .engine .Model ([x_input , t_input ], [ll ])
560580 model .add_loss (K .mean (ll ))
561581 model .compile ('nadam' )
562- model .fit ([x , t ], [], epochs = 100 )
582+ model .fit ([x , t ], epochs = 100 )
563583
564584 model2 = keras .engine .Model ([x_input ], [pi , mu , sig ])
565585 import matplotlib
0 commit comments