
 import unittest
 from functools import partial, update_wrapper
-from typing import Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
+from collections import OrderedDict
+from unittest.mock import Mock, patch

 import torch
 import torch.distributed as dist

-from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torch import Tensor
+from torchrec.metrics.metrics_config import BatchSizeStage, DefaultTaskInfo
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import (
     RecComputeMode,
@@ -159,6 +162,10 @@ def compute(


 class TowerQPSMetricTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.world_size = 1
+        self.batch_size = 256
+
     target_clazz: Type[RecMetric] = TowerQPSMetric
     task_names: str = "qps"

@@ -377,3 +384,206 @@ def test_tower_qps_update_with_invalid_tensors(self) -> None:
                     "key_2": torch.rand(batch_size),
                 },
             )
+
+    @patch("torchrec.metrics.tower_qps.time.monotonic")
+    def test_batch_size_schedule(self, time_mock: Mock) -> None:
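+        # End-to-end check that per-task example counts follow the configured
+        # batch-size schedule rather than the constructor's static batch_size.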
+
+        def _gen_data_with_batch_size(
+            batch_size: int,
+        ) -> Dict[str, Union[Dict[str, Tensor], Tensor]]:
+            return {
+                "labels": {
+                    "t1": torch.rand(batch_size),
+                    "t2": torch.rand(batch_size),
+                    "t3": torch.rand(batch_size),
+                },
+                "predictions": torch.ones(batch_size),
+                "weights": torch.rand(batch_size),
+            }
+
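+        # Two stages: 256 examples per batch for the first batch, then 512
+        # for the remainder (max_iters=None marks the final, open-ended stage).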
+        batch_size_stages = [BatchSizeStage(256, 1), BatchSizeStage(512, None)]
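+        # Freeze time.monotonic so no time elapses between updates; every
+        # lifetime/window QPS is then expected to read 0, leaving
+        # total_examples as the signal under test.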
+        time_mock.return_value = 1
+        batch_size = 256
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=batch_size,
+            world_size=1,
+            window_size=1000,
+            batch_size_stages=batch_size_stages,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        data = _gen_data_with_batch_size(batch_size_stages[0].batch_size)
+        metric.update(**data)  # pyre-ignore[6]
+
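+        # One update at the stage-1 batch size: 256 examples seen per task.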
+        self.assertEqual(
+            metric.compute(),
+            {
+                "qps-t1|lifetime_qps": 0,
+                "qps-t2|lifetime_qps": 0,
+                "qps-t3|lifetime_qps": 0,
+                "qps-t1|window_qps": 0,
+                "qps-t2|window_qps": 0,
+                "qps-t3|window_qps": 0,
+                "qps-t1|total_examples": 256,
+                "qps-t2|total_examples": 256,
+                "qps-t3|total_examples": 256,
+            },
+        )
+
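+        # The second update advances to stage 2 (batch size 512), so each
+        # task should now have seen 256 + 512 = 768 examples.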
+        data = _gen_data_with_batch_size(batch_size_stages[1].batch_size)
+        metric.update(**data)  # pyre-ignore[6]
+
+        self.assertEqual(
+            metric.compute(),
+            {
+                "qps-t1|lifetime_qps": 0,
+                "qps-t2|lifetime_qps": 0,
+                "qps-t3|lifetime_qps": 0,
+                "qps-t1|window_qps": 0,
+                "qps-t2|window_qps": 0,
+                "qps-t3|window_qps": 0,
+                "qps-t1|total_examples": 768,
+                "qps-t2|total_examples": 768,
+                "qps-t3|total_examples": 768,
+            },
+        )
+
+    def test_num_batch_without_batch_size_stages(self) -> None:
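+        # Without batch_size_stages, no batch counter should be tracked or
+        # serialized into the state dict.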
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        self.assertFalse(hasattr(metric, "num_batch"))
+
+        metric.update(
+            labels={
+                "t1": torch.rand(self.batch_size),
+                "t2": torch.rand(self.batch_size),
+                "t3": torch.rand(self.batch_size),
+            },
+            predictions=torch.ones(self.batch_size),
+            weights=torch.rand(self.batch_size),
+        )
+        state_dict: Dict[str, Any] = metric.state_dict()
+        self.assertNotIn("num_batch", state_dict)
+
+    def test_state_dict_load_module_lifecycle(self) -> None:
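+        # num_batch should survive a state_dict()/load_state_dict() round
+        # trip so a restored metric resumes the schedule where it left off.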
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+
+        self.assertTrue(hasattr(metric, "_num_batch"))
+
+        metric.update(
+            labels={
+                "t1": torch.rand(self.batch_size),
+                "t2": torch.rand(self.batch_size),
+                "t3": torch.rand(self.batch_size),
+            },
+            predictions=torch.ones(self.batch_size),
+            weights=torch.rand(self.batch_size),
+        )
+        self.assertEqual(metric._num_batch, 1)
+        state_dict = metric.state_dict()
+        self.assertIn("num_batch", state_dict)
+        self.assertEqual(state_dict["num_batch"].item(), metric._num_batch)
+
+        new_metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        self.assertEqual(new_metric._num_batch, 0)
+        new_metric.load_state_dict(state_dict)
+        self.assertEqual(new_metric._num_batch, 1)
+
+        state_dict = new_metric.state_dict()
+        self.assertIn("num_batch", state_dict)
+        self.assertEqual(state_dict["num_batch"].item(), new_metric._num_batch)
+
+    def test_state_dict_hook_adds_key(self) -> None:
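+        # The state-dict hook should store the batch counter under the
+        # caller-supplied prefix.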
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(256, None)],
+        )
+
+        for _ in range(5):
+            metric.update(
+                labels={
+                    "t1": torch.rand(self.batch_size),
+                    "t2": torch.rand(self.batch_size),
+                    "t3": torch.rand(self.batch_size),
+                },
+                predictions=torch.ones(self.batch_size),
+                weights=torch.rand(self.batch_size),
+            )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        metric.state_dict_hook(metric, state_dict, prefix, {})
+        self.assertIn(f"{prefix}num_batch", state_dict)
+        self.assertEqual(state_dict[f"{prefix}num_batch"].item(), 5)
+
+    def test_state_dict_hook_no_batch_size_stages(self) -> None:
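+        # With batch_size_stages=None the hook must not add a num_batch key.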
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=None,
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        metric.state_dict_hook(metric, state_dict, prefix, {})
+        self.assertNotIn(f"{prefix}num_batch", state_dict)
+
+    def test_load_state_dict_hook_restores_value(self) -> None:
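+        # The load hook should read the prefixed key and restore the counter
+        # onto the metric.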
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
+        metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
+        self.assertEqual(metric._num_batch, 10)