Skip to content

Commit 7f85833

Browse files
kashif and yzhangcs authored
Update torchtitan and train.py (#21)
* activations on CUDA offloaded
* add save_for_all_ranks config
* update torchtitan
* update train.py
* use build_loss_fn
* add get_nparams_and_flops
* remove unused import
* Fix isort issues

---------

Co-authored-by: Yu Zhang <yzhang.cs@outlook.com>
1 parent c949efe commit 7f85833

File tree

6 files changed

+107
-143
lines changed

6 files changed

+107
-143
lines changed

3rdparty/flash-linear-attention

3rdparty/torchtitan

Submodule torchtitan updated 92 files

flame/config_manager.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -466,15 +466,15 @@ def __init__(self):
466466
default="tb",
467467
help="Folder to dump TensorBoard states",
468468
)
469-
# TODO: store_true & default=True make impossible for cmd to set it to False
470469
self.parser.add_argument(
471-
"--metrics.rank_0_only",
470+
"--metrics.save_for_all_ranks",
472471
action="store_true",
473-
default=True,
472+
default=False,
474473
help="""
475-
Whether to save TensorBoard metrics only for rank 0 or for all ranks.
476-
When pipeline_parallel_degree is > 1, this option uses the 0th rank of the last stage pipeline group,
477-
which is the only stage that computes loss metrics.
474+
Whether to save TensorBoard/Wandb metrics only for rank 0 or for all ranks.
475+
When this option is False and pipeline_parallel_degree is > 1, the metrics
476+
component uses the 0th rank of the last stage pipeline group, which is the
477+
only stage that computes loss metrics.
478478
""",
479479
)
480480
self.parser.add_argument(

flame/models/activation_offloading.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,11 +151,15 @@ def pack_tensor(activation: torch.Tensor) -> int:
151151
num_bytes = get_num_bytes_tensor(activation)
152152
tensor_id = get_tensor_id()
153153

154-
# only offload hefty bois if they're activations (our heuristic for that is to
155-
# check if they're not params or buffers)!
156-
if num_bytes >= self.min_tensor_size_bytes and (
157-
not isinstance(activation, torch.nn.Parameter)
158-
and not isinstance(activation, torch.nn.Buffer)
154+
# only offload hefty bois if they're activations on CUDA (our heuristic
155+
# for that is to check if they're not params or buffers)!
156+
if (
157+
activation.is_cuda
158+
and num_bytes >= self.min_tensor_size_bytes
159+
and (
160+
not isinstance(activation, torch.nn.Parameter)
161+
and not isinstance(activation, torch.nn.Buffer)
162+
)
159163
):
160164
if self.use_streams:
161165
# First, sync back and dereference previously offloaded tensors

flame/tools/utils.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,18 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from torch import nn
78
from torchtitan.tools.logging import logger
89

910

10-
def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
11+
def get_nparams_and_flops(model: nn.Module, model_config, seq_len: int) -> tuple[int, int]:
12+
nparams = sum(p.numel() for p in model.parameters())
13+
nparams_embedding = sum(
14+
sum(p.numel() for p in m.parameters())
15+
for m in model.children()
16+
if isinstance(m, nn.Embedding)
17+
)
18+
1119
if hasattr(model_config, "num_heads"):
1220
num_heads = model_config.num_heads
1321
elif hasattr(model_config, "num_attention_heads"):
@@ -28,6 +36,6 @@ def get_num_flop_per_token(num_params: int, model_config, seq_len) -> int:
2836
# but recomputation should not be counted in calculating MFU (+0)
2937
# 3. each matmul performs 1 multiplication and 1 addition (*2)
3038
# 4. we follow the convention and do not account for sparsity in causal attention
31-
flop_per_token = 6 * num_params + 12 * l * h * q * t
39+
num_flops_per_token = 6 * (nparams - nparams_embedding) + 12 * l * h * q * t
3240

33-
return flop_per_token
41+
return nparams, num_flops_per_token

0 commit comments

Comments (0)