@@ -722,6 +722,88 @@ def metric_module_gather_state(
722
722
metric_module .shutdown ()
723
723
724
724
725
+ class MetricsConfigPostInitTest (unittest .TestCase ):
726
+ """Test class for MetricsConfig._post_init() validation functionality."""
727
+
728
+ def test_post_init_valid_rec_task_indices (self ) -> None :
729
+ """Test that _post_init() passes when rec_task_indices are valid."""
730
+ # Setup: create rec_tasks and valid indices
731
+ task1 = RecTaskInfo (name = "task1" , label_name = "label1" , prediction_name = "pred1" )
732
+ task2 = RecTaskInfo (name = "task2" , label_name = "label2" , prediction_name = "pred2" )
733
+ rec_tasks = [task1 , task2 ]
734
+
735
+ # Execute: create MetricsConfig with valid rec_task_indices
736
+ config = MetricsConfig (
737
+ rec_tasks = rec_tasks ,
738
+ rec_metrics = {
739
+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = [0 , 1 ]),
740
+ RecMetricEnum .NE : RecMetricDef (rec_task_indices = [0 ]),
741
+ },
742
+ )
743
+
744
+ # Assert: config should be created successfully without raising an exception
745
+ self .assertEqual (len (config .rec_tasks ), 2 )
746
+ self .assertEqual (len (config .rec_metrics ), 2 )
747
+
748
+ def test_post_init_empty_rec_task_indices (self ) -> None :
749
+ """Test that _post_init() passes when rec_task_indices is empty."""
750
+ # Setup: create rec_tasks but use empty indices
751
+ task = RecTaskInfo (name = "task" , label_name = "label" , prediction_name = "pred" )
752
+ rec_tasks = [task ]
753
+
754
+ # Execute: create MetricsConfig with empty rec_task_indices
755
+ config = MetricsConfig (
756
+ rec_tasks = rec_tasks ,
757
+ rec_metrics = {
758
+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = []),
759
+ },
760
+ )
761
+
762
+ # Assert: config should be created successfully with empty indices
763
+ self .assertEqual (len (config .rec_tasks ), 1 )
764
+ self .assertEqual (config .rec_metrics [RecMetricEnum .AUC ].rec_task_indices , [])
765
+
766
+ def test_post_init_raises_when_rec_tasks_is_none (self ) -> None :
767
+ """Test that _post_init() raises ValueError when rec_tasks is None but rec_task_indices is specified."""
768
+ # Setup: prepare to create config with None rec_tasks but specified indices
769
+
770
+ # Execute & Assert: should raise ValueError about rec_tasks being None
771
+ with self .assertRaises (ValueError ) as context :
772
+ config = MetricsConfig (
773
+ rec_tasks = None , # pyre-ignore[6]: Intentionally passing None for testing
774
+ rec_metrics = {
775
+ RecMetricEnum .AUC : RecMetricDef (rec_task_indices = [0 ]),
776
+ },
777
+ )
778
+
779
+ error_message = str (context .exception )
780
+ self .assertIn ("rec_task_indices [0] is specified" , error_message )
781
+ self .assertIn ("but rec_tasks is None" , error_message )
782
+ self .assertIn ("for metric auc" , error_message )
783
+
784
+ def test_post_init_raises_when_rec_task_index_out_of_range (self ) -> None :
785
+ """Test that _post_init() raises ValueError when rec_task_index is out of range."""
786
+ # Setup: create single rec_task but try to access index 1
787
+ task = RecTaskInfo (name = "task" , label_name = "label" , prediction_name = "pred" )
788
+ rec_tasks = [task ]
789
+
790
+ # Execute & Assert: should raise ValueError about index out of range
791
+ with self .assertRaises (ValueError ) as context :
792
+ config = MetricsConfig (
793
+ rec_tasks = rec_tasks ,
794
+ rec_metrics = {
795
+ RecMetricEnum .NE : RecMetricDef (
796
+ rec_task_indices = [1 ]
797
+ ), # Index 1 doesn't exist
798
+ },
799
+ )
800
+
801
+ error_message = str (context .exception )
802
+ self .assertIn ("rec_task_indices 1 is out of range" , error_message )
803
+ self .assertIn ("of 1 tasks" , error_message )
804
+ self .assertIn ("for metric ne" , error_message )
805
+
806
+
725
807
@skip_if_asan_class
726
808
class MetricModuleDistributedTest (MultiProcessTestBase ):
727
809
0 commit comments