Commit d991b46

iamzainhuda authored and meta-codesync[bot] committed
add variable batch size support to tower QPS (#3438)
Summary: Pull Request resolved: #3438

Add variable batch size support to tower QPS. This applies under the fused RecMetrics task mode, because fusion concatenates the state tensors across tasks to compute the metric value more efficiently.

Future TODOs: examine other metrics with a batch size dependency, and move batch size scheduling to the module level (RecMetricModule/RecMetric) so that batch_size can be passed as a parameter to update() according to the schedule, rather than being set up on a per-metric basis.

Reviewed By: irobert0126, AKhazane

Differential Revision: D83700799

fbshipit-source-id: a9e36c8485c4fe893525fab5213219e6d06df60b
1 parent d3722b6 commit d991b46
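For readers unfamiliar with batch size stages, the schedule this commit relies on can be illustrated standalone. The sketch below mirrors the _get_batch_size loop added in torchrec/metrics/tower_qps.py; Stage and resolve_batch_size are hypothetical stand-ins for illustration (the shipped type is BatchSizeStage and the logic lives in a private method on the metric), not torchrec API.

# Minimal sketch; Stage/resolve_batch_size are assumed stand-ins, not torchrec API.
# Each stage's batch size applies while num_batch <= max_iters; a terminal stage
# with max_iters=None applies forever, mirroring _get_batch_size in this diff.
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Stage:
    batch_size: int
    max_iters: Optional[int]

def resolve_batch_size(stages: List[Stage], num_batch: int) -> int:
    while stages:
        stage = stages[0]
        if stage.max_iters is None:
            return stage.batch_size  # terminal stage, applies indefinitely
        if stage.max_iters < num_batch:
            stages.pop(0)  # stage exhausted, drop it and try the next
            continue
        return stage.batch_size
    raise AssertionError("schedule must end with a max_iters=None stage")

stages = [Stage(256, 1), Stage(512, None)]
print(resolve_batch_size(stages, 1))  # 256: first update, stage 1 still active
print(resolve_batch_size(stages, 2))  # 512: stage 1 exhausted after 1 iteration

This matches the new test_batch_size_schedule test below: the first update counts 256 examples per task, the second counts 512, for 768 total.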

File tree

2 files changed: +283 −4 lines changed


torchrec/metrics/tests/test_tower_qps.py

Lines changed: 212 additions & 2 deletions

@@ -10,11 +10,14 @@
 
 import unittest
 from functools import partial, update_wrapper
-from typing import Callable, Dict, List, Optional, Tuple, Type
+from typing import Any, Callable, Dict, List, Optional, OrderedDict, Tuple, Type, Union
+from collections import OrderedDict
+from unittest.mock import Mock, patch
 
 import torch
 import torch.distributed as dist
-from torchrec.metrics.metrics_config import DefaultTaskInfo
+from torch import Tensor
+from torchrec.metrics.metrics_config import BatchSizeStage, DefaultTaskInfo
 from torchrec.metrics.model_utils import parse_task_model_outputs
 from torchrec.metrics.rec_metric import (
     RecComputeMode,
@@ -159,6 +162,10 @@ def compute(
 
 
 class TowerQPSMetricTest(unittest.TestCase):
+    def setUp(self) -> None:
+        self.world_size = 1
+        self.batch_size = 256
+
     target_clazz: Type[RecMetric] = TowerQPSMetric
     task_names: str = "qps"
 
@@ -377,3 +384,206 @@ def test_tower_qps_update_with_invalid_tensors(self) -> None:
                 "key_2": torch.rand(batch_size),
             },
         )
+
+    @patch("torchrec.metrics.tower_qps.time.monotonic")
+    def test_batch_size_schedule(self, time_mock: Mock) -> None:
+
+        def _gen_data_with_batch_size(
+            batch_size: int,
+        ) -> Dict[str, Union[Dict[str, Tensor], Tensor]]:
+            return {
+                "labels": {
+                    "t1": torch.rand(batch_size),
+                    "t2": torch.rand(batch_size),
+                    "t3": torch.rand(batch_size),
+                },
+                "predictions": torch.ones(batch_size),
+                "weights": torch.rand(batch_size),
+            }
+
+        batch_size_stages = [BatchSizeStage(256, 1), BatchSizeStage(512, None)]
+        time_mock.return_value = 1
+        batch_size = 256
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=batch_size,
+            world_size=1,
+            window_size=1000,
+            batch_size_stages=batch_size_stages,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        data = _gen_data_with_batch_size(batch_size_stages[0].batch_size)
+        metric.update(**data)  # pyre-ignore[6]
+
+        self.assertEqual(
+            metric.compute(),
+            {
+                "qps-t1|lifetime_qps": 0,
+                "qps-t2|lifetime_qps": 0,
+                "qps-t3|lifetime_qps": 0,
+                "qps-t1|window_qps": 0,
+                "qps-t2|window_qps": 0,
+                "qps-t3|window_qps": 0,
+                "qps-t1|total_examples": 256,
+                "qps-t2|total_examples": 256,
+                "qps-t3|total_examples": 256,
+            },
+        )
+
+        data = _gen_data_with_batch_size(batch_size_stages[1].batch_size)
+        metric.update(**data)  # pyre-ignore[6]
+
+        self.assertEqual(
+            metric.compute(),
+            {
+                "qps-t1|lifetime_qps": 0,
+                "qps-t2|lifetime_qps": 0,
+                "qps-t3|lifetime_qps": 0,
+                "qps-t1|window_qps": 0,
+                "qps-t2|window_qps": 0,
+                "qps-t3|window_qps": 0,
+                "qps-t1|total_examples": 768,
+                "qps-t2|total_examples": 768,
+                "qps-t3|total_examples": 768,
+            },
+        )
+
+    def test_num_batch_without_batch_size_stages(self) -> None:
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+        )
+
+        self.assertFalse(hasattr(metric, "num_batch"))
+
+        metric.update(
+            labels={
+                "t1": torch.rand(self.batch_size),
+                "t2": torch.rand(self.batch_size),
+                "t3": torch.rand(self.batch_size),
+            },
+            predictions=torch.ones(self.batch_size),
+            weights=torch.rand(self.batch_size),
+        )
+        state_dict: Dict[str, Any] = metric.state_dict()
+        self.assertNotIn("num_batch", state_dict)
+
+    def test_state_dict_load_module_lifecycle(self) -> None:
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+
+        self.assertTrue(hasattr(metric, "_num_batch"))
+
+        metric.update(
+            labels={
+                "t1": torch.rand(self.batch_size),
+                "t2": torch.rand(self.batch_size),
+                "t3": torch.rand(self.batch_size),
+            },
+            predictions=torch.ones(self.batch_size),
+            weights=torch.rand(self.batch_size),
+        )
+        self.assertEqual(metric._num_batch, 1)
+        state_dict = metric.state_dict()
+        self.assertIn("num_batch", state_dict)
+        self.assertEqual(state_dict["num_batch"].item(), metric._num_batch)
+
+        new_metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        self.assertEqual(new_metric._num_batch, 0)
+        new_metric.load_state_dict(state_dict)
+        self.assertEqual(new_metric._num_batch, 1)
+
+        state_dict = new_metric.state_dict()
+        self.assertIn("num_batch", state_dict)
+        self.assertEqual(state_dict["num_batch"].item(), new_metric._num_batch)
+
+    def test_state_dict_hook_adds_key(self) -> None:
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(256, None)],
+        )
+
+        for _ in range(5):
+            metric.update(
+                labels={
+                    "t1": torch.rand(self.batch_size),
+                    "t2": torch.rand(self.batch_size),
+                    "t3": torch.rand(self.batch_size),
+                },
+                predictions=torch.ones(self.batch_size),
+                weights=torch.rand(self.batch_size),
+            )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        metric.state_dict_hook(metric, state_dict, prefix, {})
+        self.assertIn(f"{prefix}num_batch", state_dict)
+        self.assertEqual(state_dict[f"{prefix}num_batch"].item(), 5)
+
+    def test_state_dict_hook_no_batch_size_stages(self) -> None:
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=None,
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        metric.state_dict_hook(metric, state_dict, prefix, {})
+        self.assertNotIn(f"{prefix}num_batch", state_dict)
+
+    def test_load_state_dict_hook_restores_value(self) -> None:
+        task_names = ["t1", "t2", "t3"]
+        tasks = gen_test_tasks(task_names)
+        metric = TowerQPSMetric(
+            my_rank=0,
+            tasks=tasks,
+            batch_size=self.batch_size,
+            world_size=self.world_size,
+            window_size=1000,
+            compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
+            batch_size_stages=[BatchSizeStage(256, 1), BatchSizeStage(512, None)],
+        )
+        state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+        prefix: str = "test_prefix_"
+        state_dict[f"{prefix}num_batch"] = torch.tensor(10, dtype=torch.long)
+        metric.load_state_dict_hook(state_dict, prefix, {}, True, [], [], [])
+        self.assertEqual(metric._num_batch, 10)

torchrec/metrics/tower_qps.py

Lines changed: 71 additions & 2 deletions

@@ -7,13 +7,16 @@
 
 # pyre-strict
 
+import copy
 import time
-from typing import Any, cast, Dict, List, Optional, Type
+from typing import Any, cast, Dict, List, Optional, OrderedDict, Type
 
 import torch
 import torch.distributed as dist
+from torch import nn
+from torchrec.distributed.utils import none_throws
 
-from torchrec.metrics.metrics_config import RecComputeMode, RecTaskInfo
+from torchrec.metrics.metrics_config import BatchSizeStage, RecComputeMode, RecTaskInfo
 from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
 from torchrec.metrics.rec_metric import (
     MetricComputationReport,
@@ -194,6 +197,7 @@ def __init__(
         fused_update_limit: int = 0,
         process_group: Optional[dist.ProcessGroup] = None,
         warmup_steps: int = WARMUP_STEPS,
+        batch_size_stages: Optional[List[BatchSizeStage]] = None,
         **kwargs: Any,
     ) -> None:
         if fused_update_limit > 0:
@@ -213,6 +217,18 @@
             **kwargs,
         )
 
+        self._batch_size = batch_size
+        self._world_size = world_size
+        self._batch_size_stages: Optional[List[BatchSizeStage]] = copy.deepcopy(
+            batch_size_stages
+        )
+
+        if self._batch_size_stages is not None:
+            self._num_batch: int = 0
+
+            self._register_load_state_dict_pre_hook(self.load_state_dict_hook)
+            self.register_state_dict_post_hook(self.state_dict_hook)
+
     def update(
         self,
         *,
@@ -221,6 +237,9 @@ def update(
         weights: Optional[RecModelOutput],
         **kwargs: Dict[str, Any],
     ) -> None:
+        if self._batch_size_stages is not None:
+            self._num_batch += 1
+            self._batch_size = self._get_batch_size()
         with torch.no_grad():
             if self._compute_mode in [
                 RecComputeMode.FUSED_TASKS_COMPUTATION,
@@ -313,3 +332,53 @@ def update(
                 labels=task_labels,
                 weights=None,
             )
+
+    def _get_batch_size(self) -> int:
+        if not self._batch_size_stages:
+            return self._batch_size
+
+        batch_size_stages = none_throws(self._batch_size_stages)
+        while self._batch_size_stages:
+            stage = self._batch_size_stages[0]
+            if stage.max_iters is None:
+                assert len(batch_size_stages) == 1
+                return stage.batch_size
+            if stage.max_iters < self._num_batch:
+                batch_size_stages.pop(0)
+                continue
+            return stage.batch_size
+        raise AssertionError("Unreachable, batch_size_stages should always have 1 item")
+
+    @staticmethod
+    def state_dict_hook(
+        module: nn.Module,
+        state_dict: OrderedDict[str, torch.Tensor],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+    ) -> None:
+        """
+        The state dict hook and load state dict hook exist to ensure we load num_batch for a metric with
+        batch_size_stages set. The reason we take this approach, as opposed to saving num_batch as a buffer,
+        is that in some cases we access the value from a CPU workload while the tensors are on GPU. This
+        would incur a device-to-host call, which is expensive and blocking.
+        """
+        if module._batch_size_stages is not None:
+            num_batch_key = f"{prefix}num_batch"
+            state_dict[num_batch_key] = torch.tensor(
+                module._num_batch, dtype=torch.long
+            )
+
+    def load_state_dict_hook(
+        self,
+        state_dict: OrderedDict[str, torch.Tensor],
+        prefix: str,
+        local_metadata: Dict[str, Any],
+        strict: bool,
+        missing_keys: List[str],
+        unexpected_keys: List[str],
+        error_msgs: List[str],
+    ) -> None:
+        key = f"{prefix}num_batch"
+        if key in state_dict and self._batch_size_stages is not None:
+            num_batch_tensor = state_dict.pop(key)
+            self._num_batch = int(num_batch_tensor.item())
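
To make the hook lifecycle concrete, here is a hedged usage sketch of the checkpoint round-trip, modeled on test_state_dict_load_module_lifecycle above. It assumes gen_test_tasks is importable from torchrec.metrics.test_utils (as in the test file) and uses a single task for brevity; treat it as a sketch of the flow, not a verified snippet.

import torch
from torchrec.metrics.metrics_config import BatchSizeStage
from torchrec.metrics.rec_metric import RecComputeMode
from torchrec.metrics.test_utils import gen_test_tasks  # test helper, assumed importable
from torchrec.metrics.tower_qps import TowerQPSMetric

tasks = gen_test_tasks(["t1"])
stages = [BatchSizeStage(256, 1), BatchSizeStage(512, None)]
metric = TowerQPSMetric(
    my_rank=0,
    tasks=tasks,
    batch_size=256,
    world_size=1,
    window_size=1000,
    batch_size_stages=stages,  # the constructor deep-copies, so reuse below is safe
    compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
)
metric.update(
    labels={"t1": torch.rand(256)},
    predictions=torch.ones(256),
    weights=torch.rand(256),
)

checkpoint = metric.state_dict()  # post-hook adds the "num_batch" key

restored = TowerQPSMetric(
    my_rank=0,
    tasks=tasks,
    batch_size=256,
    world_size=1,
    window_size=1000,
    batch_size_stages=stages,
    compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION,
)
restored.load_state_dict(checkpoint)  # pre-hook pops "num_batch" and restores it
assert restored._num_batch == metric._num_batch  # schedule resumes where it left off

Keeping num_batch out of the module's buffers avoids a device-to-host sync when a CPU workload reads the counter, which is the design rationale stated in the state_dict_hook docstring.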
