Skip to content

Commit 283e2f8

Browse files
nipung90meta-codesync[bot]
authored andcommitted
Add validations for rec metrics config creation to avoid out of bounds indices (#3421)
Summary: Pull Request resolved: #3421 This validation would avoid issues like: https://fb.workplace.com/groups/755371733754414/posts/794025843514383/ Reviewed By: iamzainhuda Differential Revision: D83764096 fbshipit-source-id: 3ba63a3849c974eb8ce4f2c09b3bfcf024278c43
1 parent 858d00b commit 283e2f8

File tree

2 files changed

+95
-0
lines changed

2 files changed

+95
-0
lines changed

torchrec/metrics/metrics_config.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,19 @@ class MetricsConfig:
188188
enable_pt2_compile: bool = False
189189
should_clone_update_inputs: bool = False
190190

191+
def __post_init__(self) -> None:
192+
for metric_enum, metric_def in self.rec_metrics.items():
193+
if metric_def.rec_task_indices:
194+
if self.rec_tasks is None:
195+
raise ValueError(
196+
f"rec_task_indices {metric_def.rec_task_indices} is specified, but rec_tasks is None for metric {metric_enum}"
197+
)
198+
for idx in metric_def.rec_task_indices:
199+
if idx >= len(self.rec_tasks):
200+
raise ValueError(
201+
f"rec_task_indices {idx} is out of range of {len(self.rec_tasks)} tasks for metric {metric_enum}"
202+
)
203+
191204

192205
DefaultTaskInfo = RecTaskInfo(
193206
name="DefaultTask",

torchrec/metrics/tests/test_metric_module.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -722,6 +722,88 @@ def metric_module_gather_state(
722722
metric_module.shutdown()
723723

724724

725+
class MetricsConfigPostInitTest(unittest.TestCase):
726+
"""Test class for MetricsConfig._post_init() validation functionality."""
727+
728+
def test_post_init_valid_rec_task_indices(self) -> None:
729+
"""Test that _post_init() passes when rec_task_indices are valid."""
730+
# Setup: create rec_tasks and valid indices
731+
task1 = RecTaskInfo(name="task1", label_name="label1", prediction_name="pred1")
732+
task2 = RecTaskInfo(name="task2", label_name="label2", prediction_name="pred2")
733+
rec_tasks = [task1, task2]
734+
735+
# Execute: create MetricsConfig with valid rec_task_indices
736+
config = MetricsConfig(
737+
rec_tasks=rec_tasks,
738+
rec_metrics={
739+
RecMetricEnum.AUC: RecMetricDef(rec_task_indices=[0, 1]),
740+
RecMetricEnum.NE: RecMetricDef(rec_task_indices=[0]),
741+
},
742+
)
743+
744+
# Assert: config should be created successfully without raising an exception
745+
self.assertEqual(len(config.rec_tasks), 2)
746+
self.assertEqual(len(config.rec_metrics), 2)
747+
748+
def test_post_init_empty_rec_task_indices(self) -> None:
749+
"""Test that _post_init() passes when rec_task_indices is empty."""
750+
# Setup: create rec_tasks but use empty indices
751+
task = RecTaskInfo(name="task", label_name="label", prediction_name="pred")
752+
rec_tasks = [task]
753+
754+
# Execute: create MetricsConfig with empty rec_task_indices
755+
config = MetricsConfig(
756+
rec_tasks=rec_tasks,
757+
rec_metrics={
758+
RecMetricEnum.AUC: RecMetricDef(rec_task_indices=[]),
759+
},
760+
)
761+
762+
# Assert: config should be created successfully with empty indices
763+
self.assertEqual(len(config.rec_tasks), 1)
764+
self.assertEqual(config.rec_metrics[RecMetricEnum.AUC].rec_task_indices, [])
765+
766+
def test_post_init_raises_when_rec_tasks_is_none(self) -> None:
767+
"""Test that _post_init() raises ValueError when rec_tasks is None but rec_task_indices is specified."""
768+
# Setup: prepare to create config with None rec_tasks but specified indices
769+
770+
# Execute & Assert: should raise ValueError about rec_tasks being None
771+
with self.assertRaises(ValueError) as context:
772+
config = MetricsConfig(
773+
rec_tasks=None, # pyre-ignore[6]: Intentionally passing None for testing
774+
rec_metrics={
775+
RecMetricEnum.AUC: RecMetricDef(rec_task_indices=[0]),
776+
},
777+
)
778+
779+
error_message = str(context.exception)
780+
self.assertIn("rec_task_indices [0] is specified", error_message)
781+
self.assertIn("but rec_tasks is None", error_message)
782+
self.assertIn("for metric auc", error_message)
783+
784+
def test_post_init_raises_when_rec_task_index_out_of_range(self) -> None:
785+
"""Test that _post_init() raises ValueError when rec_task_index is out of range."""
786+
# Setup: create single rec_task but try to access index 1
787+
task = RecTaskInfo(name="task", label_name="label", prediction_name="pred")
788+
rec_tasks = [task]
789+
790+
# Execute & Assert: should raise ValueError about index out of range
791+
with self.assertRaises(ValueError) as context:
792+
config = MetricsConfig(
793+
rec_tasks=rec_tasks,
794+
rec_metrics={
795+
RecMetricEnum.NE: RecMetricDef(
796+
rec_task_indices=[1]
797+
), # Index 1 doesn't exist
798+
},
799+
)
800+
801+
error_message = str(context.exception)
802+
self.assertIn("rec_task_indices 1 is out of range", error_message)
803+
self.assertIn("of 1 tasks", error_message)
804+
self.assertIn("for metric ne", error_message)
805+
806+
725807
@skip_if_asan_class
726808
class MetricModuleDistributedTest(MultiProcessTestBase):
727809

0 commit comments

Comments
 (0)