@@ -461,3 +461,171 @@ def test_axis(self):
461461 loss = cosine_obj (self .y_true , self .y_pred )
462462 expected_loss = - np .mean (self .expected_loss )
463463 self .assertAlmostEqual (loss , expected_loss , 3 )
464+
465+
466+ class KLDivergenceTest (testing .TestCase ):
467+ def setup (self ):
468+ self .y_pred = np .asarray (
469+ [0.4 , 0.9 , 0.12 , 0.36 , 0.3 , 0.4 ], dtype = np .float32
470+ ).reshape ((2 , 3 ))
471+ self .y_true = np .asarray (
472+ [0.5 , 0.8 , 0.12 , 0.7 , 0.43 , 0.8 ], dtype = np .float32
473+ ).reshape ((2 , 3 ))
474+
475+ self .batch_size = 2
476+ self .expected_losses = np .multiply (
477+ self .y_true , np .log (self .y_true / self .y_pred )
478+ )
479+
480+ def test_config (self ):
481+ k_obj = losses .KLDivergence (reduction = "sum" , name = "kld" )
482+ self .assertEqual (k_obj .name , "kld" )
483+ self .assertEqual (k_obj .reduction , "sum" )
484+
485+ def test_unweighted (self ):
486+ self .setup ()
487+ k_obj = losses .KLDivergence ()
488+
489+ loss = k_obj (self .y_true , self .y_pred )
490+ expected_loss = np .sum (self .expected_losses ) / self .batch_size
491+ self .assertAlmostEqual (loss , expected_loss , 3 )
492+
493+ def test_scalar_weighted (self ):
494+ self .setup ()
495+ k_obj = losses .KLDivergence ()
496+ sample_weight = 2.3
497+
498+ loss = k_obj (self .y_true , self .y_pred , sample_weight = sample_weight )
499+ expected_loss = (
500+ sample_weight * np .sum (self .expected_losses ) / self .batch_size
501+ )
502+ self .assertAlmostEqual (loss , expected_loss , 3 )
503+
504+ # Verify we get the same output when the same input is given
505+ loss_2 = k_obj (self .y_true , self .y_pred , sample_weight = sample_weight )
506+ self .assertAlmostEqual (loss , loss_2 , 3 )
507+
508+ def test_sample_weighted (self ):
509+ self .setup ()
510+ k_obj = losses .KLDivergence ()
511+ sample_weight = np .asarray ([1.2 , 3.4 ], dtype = np .float32 ).reshape ((2 , 1 ))
512+ loss = k_obj (self .y_true , self .y_pred , sample_weight = sample_weight )
513+
514+ expected_loss = np .multiply (
515+ self .expected_losses ,
516+ np .asarray (
517+ [1.2 , 1.2 , 1.2 , 3.4 , 3.4 , 3.4 ], dtype = np .float32
518+ ).reshape (2 , 3 ),
519+ )
520+ expected_loss = np .sum (expected_loss ) / self .batch_size
521+ self .assertAlmostEqual (loss , expected_loss , 3 )
522+
523+ def test_timestep_weighted (self ):
524+ self .setup ()
525+ k_obj = losses .KLDivergence ()
526+ y_true = self .y_true .reshape (2 , 3 , 1 )
527+ y_pred = self .y_pred .reshape (2 , 3 , 1 )
528+ sample_weight = np .asarray ([3 , 6 , 5 , 0 , 4 , 2 ]).reshape (2 , 3 )
529+ expected_losses = np .sum (
530+ np .multiply (y_true , np .log (y_true / y_pred )), axis = - 1
531+ )
532+ loss = k_obj (y_true , y_pred , sample_weight = sample_weight )
533+
534+ num_timesteps = 3
535+ expected_loss = np .sum (expected_losses * sample_weight ) / (
536+ self .batch_size * num_timesteps
537+ )
538+ self .assertAlmostEqual (loss , expected_loss , 3 )
539+
540+ def test_zero_weighted (self ):
541+ self .setup ()
542+ k_obj = losses .KLDivergence ()
543+ loss = k_obj (self .y_true , self .y_pred , sample_weight = 0 )
544+ self .assertAlmostEqual (loss , 0.0 , 3 )
545+
546+
547+ class PoissonTest (testing .TestCase ):
548+ def setup (self ):
549+ self .y_pred = np .asarray ([1 , 9 , 2 , 5 , 2 , 6 ], dtype = np .float32 ).reshape (
550+ (2 , 3 )
551+ )
552+ self .y_true = np .asarray ([4 , 8 , 12 , 8 , 1 , 3 ], dtype = np .float32 ).reshape (
553+ (2 , 3 )
554+ )
555+
556+ self .batch_size = 6
557+ self .expected_losses = self .y_pred - np .multiply (
558+ self .y_true , np .log (self .y_pred )
559+ )
560+
561+ def test_config (self ):
562+ poisson_obj = losses .Poisson (reduction = "sum" , name = "poisson" )
563+ self .assertEqual (poisson_obj .name , "poisson" )
564+ self .assertEqual (poisson_obj .reduction , "sum" )
565+
566+ def test_unweighted (self ):
567+ self .setup ()
568+ poisson_obj = losses .Poisson ()
569+
570+ loss = poisson_obj (self .y_true , self .y_pred )
571+ expected_loss = np .sum (self .expected_losses ) / self .batch_size
572+ self .assertAlmostEqual (loss , expected_loss , 3 )
573+
574+ def test_scalar_weighted (self ):
575+ self .setup ()
576+ poisson_obj = losses .Poisson ()
577+ sample_weight = 2.3
578+ loss = poisson_obj (
579+ self .y_true , self .y_pred , sample_weight = sample_weight
580+ )
581+ expected_loss = (
582+ sample_weight * np .sum (self .expected_losses ) / self .batch_size
583+ )
584+ self .assertAlmostEqual (loss , expected_loss , 3 )
585+ self .assertAlmostEqual (loss , expected_loss , 3 )
586+
587+ # Verify we get the same output when the same input is given
588+ loss_2 = poisson_obj (
589+ self .y_true , self .y_pred , sample_weight = sample_weight
590+ )
591+ self .assertAlmostEqual (loss , loss_2 , 3 )
592+
593+ def test_sample_weighted (self ):
594+ self .setup ()
595+ poisson_obj = losses .Poisson ()
596+
597+ sample_weight = np .asarray ([1.2 , 3.4 ]).reshape ((2 , 1 ))
598+ loss = poisson_obj (
599+ self .y_true , self .y_pred , sample_weight = sample_weight
600+ )
601+
602+ expected_loss = np .multiply (
603+ self .expected_losses ,
604+ np .asarray ([1.2 , 1.2 , 1.2 , 3.4 , 3.4 , 3.4 ]).reshape ((2 , 3 )),
605+ )
606+ expected_loss = np .sum (expected_loss ) / self .batch_size
607+ self .assertAlmostEqual (loss , expected_loss , 3 )
608+
609+ def test_timestep_weighted (self ):
610+ self .setup ()
611+ poisson_obj = losses .Poisson ()
612+ y_true = self .y_true .reshape (2 , 3 , 1 )
613+ y_pred = self .y_pred .reshape (2 , 3 , 1 )
614+ sample_weight = np .asarray ([3 , 6 , 5 , 0 , 4 , 2 ]).reshape (2 , 3 , 1 )
615+ expected_losses = y_pred - np .multiply (y_true , np .log (y_pred ))
616+
617+ loss = poisson_obj (
618+ y_true ,
619+ y_pred ,
620+ sample_weight = np .asarray (sample_weight ).reshape ((2 , 3 )),
621+ )
622+ expected_loss = (
623+ np .sum (expected_losses * sample_weight ) / self .batch_size
624+ )
625+ self .assertAlmostEqual (loss , expected_loss , 3 )
626+
627+ def test_zero_weighted (self ):
628+ self .setup ()
629+ poisson_obj = losses .Poisson ()
630+ loss = poisson_obj (self .y_true , self .y_pred , sample_weight = 0 )
631+ self .assertAlmostEqual (loss , 0.0 , 3 )
0 commit comments