Skip to content

Commit 107678b

Browse files
jeffkbkimmeta-codesync[bot]
authored andcommitted
Foundation (#3423)
Summary: Pull Request resolved: #3423 This diff adds the basic building blocks for a zero overhead RecMetrics implementation. Follow up patches will contain integration with users of torchrec. One of the main pain points of using RecMetricModule is that metric updates and computes are done synchronously. In training jobs, there has been cases where metric updates take +20% of a training iteration. Metric computations, although less frequent, can takes over a couple of seconds. CPUOffloadedRecMetricModule aims to perform all metric updates/computes asynchronously, completely removing them from the critical path. This patch adds three classes: - MetricStateSnapshot: Encapsulation of metric state tensors that will be used to load into comms module for all gathers. - MetricUpdateJob/MetricComputeJob/SynchronizationMarker: classes to be enqueued to metric and compute queues in the future - PercentileLogger: utility class to log percentiles of queue sizes, queue times Reviewed By: iamzainhuda Differential Revision: D83173329 fbshipit-source-id: 9b1760962a361fcfb824e72c8ee4721c0012d8d8
1 parent d958b4d commit 107678b

File tree

6 files changed

+944
-1
lines changed

6 files changed

+944
-1
lines changed
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import concurrent
11+
from typing import Any, Dict
12+
13+
import torch
14+
from torchrec.metrics.metric_module import MetricValue
15+
from torchrec.metrics.metric_state_snapshot import MetricStateSnapshot
16+
17+
18+
class MetricUpdateJob:
19+
"""
20+
Encapsulates metric update job for CPU processing:
21+
update each metric state tensors with intermediate model outputs
22+
"""
23+
24+
__slots__ = ["model_out", "transfer_completed_event", "kwargs"]
25+
26+
def __init__(
27+
self,
28+
model_out: Dict[str, torch.Tensor],
29+
transfer_completed_event: torch.cuda.Event,
30+
kwargs: Dict[str, Any],
31+
) -> None:
32+
"""
33+
Args:
34+
model_out: intermediate model outputs to be used for metric updates
35+
transfer_completed_event: cuda event to track when the transfer to CPU is completed
36+
kwargs: additional arguments from the trainer platform
37+
"""
38+
39+
self.model_out: Dict[str, torch.Tensor] = model_out
40+
self.transfer_completed_event: torch.cuda.Event = transfer_completed_event
41+
self.kwargs: Dict[str, Any] = kwargs
42+
43+
44+
class MetricComputeJob:
45+
"""
46+
Encapsulates metric compute job for CPU processing: perform an
47+
all gather across ranks, compute metrics, and return the result to be
48+
published.
49+
"""
50+
51+
__slots__ = ["future", "metric_state_snapshot"]
52+
53+
def __init__(
54+
self,
55+
future: concurrent.futures.Future[Dict[str, MetricValue]],
56+
metric_state_snapshot: MetricStateSnapshot,
57+
) -> None:
58+
"""
59+
Args:
60+
future: future to set the result of the compute job. Contains the computed metrics.
61+
metric_state_snapshot: snapshot of metric state tensors across all metrics types.
62+
"""
63+
self.future: concurrent.futures.Future[Dict[str, MetricValue]] = future
64+
self.metric_state_snapshot: MetricStateSnapshot = metric_state_snapshot
65+
66+
67+
class SynchronizationMarker:
68+
"""
69+
Represents the synchronization marker that is stored in the update queue. This is the point
70+
we want to synchronize across all ranks to compute metrics.
71+
When processed, this marker will convert to a MetricComputeJob in the compute queue.
72+
73+
This separation of synchronization marker and compute job is so that the metric compute job
74+
accurately includes all of the metric jobs that were scheduled before it.
75+
"""
76+
77+
__slots__ = "future"
78+
79+
def __init__(
80+
self,
81+
future: concurrent.futures.Future[Dict[str, MetricValue]],
82+
) -> None:
83+
"""
84+
Args:
85+
future: future to set the result of the compute job. Passed to the MetricComputeJob.
86+
"""
87+
self.future: concurrent.futures.Future[Dict[str, MetricValue]] = future

torchrec/metrics/metric_module.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
#!/usr/bin/env python3
1111

1212
import abc
13+
import concurrent
1314
import logging
1415
import time
1516
from collections import defaultdict
@@ -56,7 +57,7 @@
5657
from torchrec.metrics.precision import PrecisionMetric
5758
from torchrec.metrics.precision_session import PrecisionSessionMetric
5859
from torchrec.metrics.rauc import RAUCMetric
59-
from torchrec.metrics.rec_metric import RecMetric, RecMetricList
60+
from torchrec.metrics.rec_metric import RecMetric, RecMetricException, RecMetricList
6061
from torchrec.metrics.recall import RecallMetric
6162
from torchrec.metrics.recall_session import RecallSessionMetric
6263
from torchrec.metrics.scalar import ScalarMetric
@@ -486,6 +487,14 @@ def load_pre_compute_states(
486487
for name, buf in self.throughput_metric.named_buffers(): # pyre-ignore[16]
487488
buf.copy_(states[name])
488489

490+
def shutdown(self) -> None:
491+
logger.info("Initiating graceful shutdown...")
492+
493+
def async_compute(
494+
self, future: concurrent.futures.Future[Dict[str, MetricValue]]
495+
) -> None:
496+
raise RecMetricException("async_compute is not supported in RecMetricModule")
497+
489498

490499
def _generate_rec_metrics(
491500
metrics_config: MetricsConfig,
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
#!/usr/bin/env python3
2+
# Copyright (c) Meta Platforms, Inc. and affiliates.
3+
# All rights reserved.
4+
#
5+
# This source code is licensed under the BSD-style license found in the
6+
# LICENSE file in the root directory of this source tree.
7+
8+
# pyre-strict
9+
10+
import copy
11+
from typing import Any, cast, Dict, Optional
12+
13+
from torch import nn
14+
15+
from torchrec.metrics.rec_metric import (
16+
RecComputeMode,
17+
RecMetric,
18+
RecMetricComputation,
19+
RecMetricList,
20+
)
21+
from torchrec.metrics.throughput import ThroughputMetric
22+
23+
24+
class MetricStateSnapshot:
25+
"""
26+
Encapsulates both rec metrics reduced states and throughput metric snapshots
27+
for thread-safe CPU offloaded metric computation (updates and computes).
28+
"""
29+
30+
def __init__(
31+
self,
32+
metric_states: Dict[str, Any],
33+
throughput_metric: Optional[ThroughputMetric],
34+
) -> None:
35+
"""
36+
Args:
37+
metric_states (Dict[str, Any]): Reduced states from rec metrics
38+
throughput_metric (Optional[ThroughputMetric]): Deep copy of throughput metric
39+
"""
40+
self.metric_states = metric_states
41+
self.throughput_metric = throughput_metric
42+
43+
@classmethod
44+
def from_metrics(
45+
cls,
46+
rec_metrics: RecMetricList,
47+
throughput_metric: Optional[ThroughputMetric] = None,
48+
) -> "MetricStateSnapshot":
49+
"""
50+
Generate a MetricStateSnapshot before performing an all gather. This provides a consistent
51+
view of the local metric states without accessing the original references.
52+
53+
Apply reductions BEFORE queuing to reduce memory footprint. For instance, AUC holds a list of
54+
tensors which can be reduced to a list of a single tensor. Only reduce lists for
55+
fused mode compatibility.
56+
"""
57+
reduced_states: Dict[str, Any] = {}
58+
59+
for metric in rec_metrics.rec_metrics:
60+
metric = cast(RecMetric, metric)
61+
compute_mode = metric._compute_mode
62+
if (
63+
compute_mode == RecComputeMode.FUSED_TASKS_COMPUTATION
64+
or compute_mode == RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION
65+
):
66+
computation = metric._metrics_computations[0]
67+
_load_into_reduced_states(
68+
compute_mode.name, computation, reduced_states
69+
)
70+
else:
71+
for task, computation in zip(
72+
metric._tasks, metric._metrics_computations
73+
):
74+
_load_into_reduced_states(task.name, computation, reduced_states)
75+
76+
# Snapshot throughput metric
77+
throughput_snapshot = None
78+
if throughput_metric:
79+
throughput_snapshot = copy.deepcopy(throughput_metric)
80+
81+
return cls(
82+
metric_states=reduced_states,
83+
throughput_metric=throughput_snapshot,
84+
)
85+
86+
87+
def _load_into_reduced_states(
88+
prefix: str,
89+
computation: nn.Module,
90+
reduced_states: Dict[str, Any],
91+
) -> None:
92+
"""
93+
Load the reduced states into the reduced_states dict.
94+
95+
Args:
96+
prefix (str): prefix for the metric computation
97+
computation (nn.Module): metric computation
98+
reduced_states (Dict[str, Any]): reduced states dict to load into
99+
"""
100+
computation = cast(RecMetricComputation, computation)
101+
computation_name = f"{prefix}_{computation.__class__.__name__}"
102+
103+
for attr_name in computation._reductions:
104+
cache_key = f"{computation_name}_{attr_name}"
105+
original_value = getattr(computation, attr_name)
106+
reduction_fn = computation._reductions[attr_name]
107+
if callable(reduction_fn) and isinstance(original_value, list):
108+
reduced_value = reduction_fn(original_value)
109+
else:
110+
reduced_value = original_value
111+
112+
reduced_states[cache_key] = reduced_value

0 commit comments

Comments
 (0)