67 changes: 67 additions & 0 deletions megatron/core/distributed/distributed_data_parallel.py
@@ -461,6 +461,73 @@ def hook(*unused):
param.main_grad.add_(param.grad.data)
param.grad = None

# Nonuniform TP: gather grads from spare GPUs and scatter to core GPUs
Reviewer comment (Contributor): Inherit from DDP, make a new class and override _make_backward_post_hook(). (A sketch of this refactor follows the diff below.)

if (
self.ddp_config.tp_spares > 0
and hasattr(param, 'tensor_model_parallel')
and param.tensor_model_parallel
and parallel_state.get_tensor_model_parallel_world_size()
== self.ddp_config.tp_base
):
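# Zero-extent placeholder shape along the TP partition dimension,
# used for ranks that have nothing to send or receive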
empty_shape = list(param.shape)
empty_shape[param.partition_dim] = 0
tp_rank = parallel_state.get_tensor_model_parallel_rank()

if tp_rank < self.ddp_config.tp_base - self.ddp_config.tp_spares:
# Core GPU: receive grads from spare GPUs
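# Send list: zero-width placeholders (core ranks contribute nothing to this exchange)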
input = [
torch.empty(
empty_shape, device=param.device, dtype=param.side_grad.dtype
).contiguous()
for _ in range(parallel_state.get_tensor_model_parallel_world_size())
]
# Receive list: empty placeholders for the core senders, slices of
# side_grad for the spare senders
output = [
torch.empty(
empty_shape, device=param.device, dtype=param.side_grad.dtype
).contiguous()
for _ in range(self.ddp_config.tp_base - self.ddp_config.tp_spares)
] + [
t.contiguous()
for t in torch.split(
param.side_grad, param.recv_splits[tp_rank], dim=param.partition_dim
)
][-self.ddp_config.tp_spares :]
else:
# Spare GPU: send grads to core GPUs
input = [
t.contiguous()
for t in torch.split(
param.main_grad, param.send_splits[tp_rank], dim=param.partition_dim
)
]
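# Receive list: zero-width placeholders (spare ranks receive nothing back)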
output = [
torch.empty(
empty_shape, device=param.device, dtype=param.main_grad.dtype
).contiguous()
for _ in range(parallel_state.get_tensor_model_parallel_world_size())
]

torch.distributed.all_to_all(
output,
input,
group=parallel_state.get_tensor_model_parallel_group(),
async_op=True,
)

if self.ddp_config.overlap_grad_reduce:
self.param_to_bucket_group[param].register_grad_ready(param)

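The reviewer's point above is that this per-parameter logic would be cleaner in a dedicated subclass rather than spliced into the base backward hook. Below is a minimal sketch of that shape, assuming _make_backward_post_hook(param) is the hook factory the base class registers (as the reviewer's comment implies); the NonuniformTPDistributedDataParallel class name and the _exchange_nonuniform_tp_grads helper are hypothetical, and this is an illustration of the suggested structure rather than the PR's implementation.

```python
# Sketch only: names other than DistributedDataParallel, parallel_state, and
# _make_backward_post_hook are hypothetical.
import torch

from megatron.core import parallel_state
from megatron.core.distributed.distributed_data_parallel import DistributedDataParallel


class NonuniformTPDistributedDataParallel(DistributedDataParallel):
    """DDP variant that redistributes gradients from spare to core TP ranks."""

    def _make_backward_post_hook(self, param: torch.nn.Parameter):
        base_hook = super()._make_backward_post_hook(param)

        def hook(*unused):
            # NOTE: in the diff above the exchange runs after main_grad is
            # accumulated and before register_grad_ready(); a simple wrapper
            # cannot reproduce that exact ordering, so the real override would
            # likely re-implement the hook body instead of delegating.
            base_hook(*unused)
            if (
                self.ddp_config.tp_spares > 0
                and getattr(param, 'tensor_model_parallel', False)
                and parallel_state.get_tensor_model_parallel_world_size()
                == self.ddp_config.tp_base
            ):
                self._exchange_nonuniform_tp_grads(param)

        return hook

    def _exchange_nonuniform_tp_grads(self, param: torch.nn.Parameter) -> None:
        # Placeholder for the gather/scatter shown in the diff above: spare
        # ranks all_to_all slices of param.main_grad to the core ranks, which
        # receive them into slices of param.side_grad.
        ...
```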
24 changes: 23 additions & 1 deletion megatron/core/distributed/distributed_data_parallel_config.py
@@ -1,7 +1,7 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

from dataclasses import dataclass
from typing import Optional
from typing import Dict, List, Optional, Tuple


@dataclass
@@ -140,6 +140,28 @@ class DistributedDataParallelConfig:
delay_wgrad_compute: bool = False
"""Delay the weight gradient computation to improve batch-level communication overlapping"""

tp_base: int = 8
Reviewer comment (Contributor): Make a small config class just for NTP. (A sketch of such a class follows at the end of this diff.)

"""Base for tensor parallelism. This is the number of ranks in healthy tensor parallel groups.
Used for nonuniform tensor parallelism."""

tp_spares: int = 0
"""Number of spares for nonuniform tensor parallelism. When > 0, enables nonuniform TP mode
where (tp_base - tp_spares) ranks handle computation and tp_spares ranks provide fault tolerance."""

num_reduced_tp_dp_ranks: int = 1
"""Number of DP ranks that use reduced TP (tp_base - tp_spares). The remaining DP ranks use
full tp_base. Reduced TP ranks are assumed to come first in the global rank ordering."""

non_active_ranks_per_dp: Optional[Dict[Tuple[int, int, int], List[int]]] = None
"""Mapping of (DP rank, CP rank, PP rank) to list of non-active (spare) local TP rank IDs.
This allows specifying arbitrary GPU failures across all parallelism dimensions.
Example: {(0,0,0): [0,3], (0,1,0): [1,2], (1,0,0): [0,3]} means:
- DP rank 0, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
- DP rank 0, CP rank 1, PP rank 0 has local TP ranks 1,2 as spares
- DP rank 1, CP rank 0, PP rank 0 has local TP ranks 0,3 as spares
The number of non-active ranks must be consistent across CP replicas within each DP rank.
If None, defaults to last tp_spares ranks as non-active."""

def __post_init__(self):
import os

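Following the second review comment, the four new fields above could be grouped into a small NTP-specific config object and nested inside DistributedDataParallelConfig rather than extending it field by field. A sketch under that assumption follows; the NonuniformTPConfig name, the nested ntp field, and the enabled property are illustrative, while the field names, types, defaults, and semantics are taken from the diff.

```python
# Sketch of the "small config class just for NTP" suggested by the reviewer.
# NonuniformTPConfig and the nested `ntp` field are hypothetical names; the
# fields themselves mirror the diff above.
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple


@dataclass
class NonuniformTPConfig:
    """Settings for nonuniform tensor parallelism (NTP)."""

    tp_base: int = 8
    """Number of ranks in a healthy tensor-parallel group."""

    tp_spares: int = 0
    """Number of spare ranks; when > 0, (tp_base - tp_spares) ranks compute and
    tp_spares ranks provide fault tolerance."""

    num_reduced_tp_dp_ranks: int = 1
    """Number of DP ranks running with reduced TP (tp_base - tp_spares); the rest
    use the full tp_base. Reduced-TP DP ranks come first in the global ordering."""

    non_active_ranks_per_dp: Optional[Dict[Tuple[int, int, int], List[int]]] = None
    """(DP rank, CP rank, PP rank) -> non-active local TP ranks. None means the
    last tp_spares ranks of each affected group are treated as spares."""

    @property
    def enabled(self) -> bool:
        return self.tp_spares > 0


@dataclass
class DistributedDataParallelConfigSketch:
    """Stand-in for DistributedDataParallelConfig showing how the NTP block
    could be nested; the existing fields are omitted here."""

    ntp: NonuniformTPConfig = field(default_factory=NonuniformTPConfig)
```

Call sites in the hook would then read e.g. ddp_config.ntp.tp_spares, keeping the DDP config surface unchanged apart from the single nested field.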