Skip to content

Commit 501a5f6

Browse files
committed
add only_packing_no_scheduling for hybrid-cp
Signed-off-by: tailaim <[email protected]>
1 parent 1167117 commit 501a5f6

File tree

16 files changed

+339
-421
lines changed

16 files changed

+339
-421
lines changed

megatron/core/distributed/fsdp/src/megatron_fsdp/param_and_grad_buffer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3342,9 +3342,6 @@ def wait_bucket_ready(self, bucket_id, empty_ok=False):
33423342
# Wait for asynchronous / overlapped NCCL operations to complete.
33433343
param_gather_event, mark_bucket_ready_to_use = self.param_gather_event_map.pop(bucket_id)
33443344
param_gather_event.wait()
3345-
# debugmtl
3346-
if self.ag_stream is not None:
3347-
torch.cuda.current_stream().wait_stream(self.ag_stream)
33483345
mark_bucket_ready_to_use()
33493346

33503347
@torch.no_grad()

megatron/core/model_parallel_config.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,14 @@ class ModelParallelConfig:
6969
When enabling hybrid_context_parallel, sft_sequence_packing must be true.
7070
"""
7171

72+
hybrid_context_parallel_scheduler: str = 'balanced'
73+
"""
74+
Scheduler for hybrid context parallel.
75+
balanced: balanced scheduler for hybrid context parallel.
76+
only_packing_no_scheduling: scheduling is already handled by the data sampler,
77+
this scheduler only performs packing.
78+
"""
79+
7280
sft_sequence_packing: bool = False
7381
"""
7482
If true, enables sft sequence packing.

megatron/core/models/gpt/gpt_model.py

Lines changed: 0 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -37,65 +37,6 @@
3737
from megatron.core.transformer.transformer_config import TransformerConfig
3838
from megatron.core.utils import WrappedTensor, deprecate_inference_params
3939

40-
# #debugmtl
41-
# _ACT_GRAD_DUMP_COUNTERS = {}
42-
43-
# def _sanitize_name(name: str) -> str:
44-
# return str(name).replace('/', '_').replace('\\', '_').replace('.', '_').replace(' ', '_')
45-
46-
# def _next_act_dump_index(rank: int, layer_name: str) -> int:
47-
# key = (rank, layer_name)
48-
# cnt = _ACT_GRAD_DUMP_COUNTERS.get(key, 0) + 1
49-
# _ACT_GRAD_DUMP_COUNTERS[key] = cnt
50-
# return cnt
51-
52-
# def get_debug_hook(layer_name: str):
53-
# """
54-
# Tensor-level grad hook: save activation grad by (rank, layer_name, index).
55-
# """
56-
# import os
57-
# def hook(grad: torch.Tensor):
58-
# if grad is None:
59-
# return
60-
61-
# rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else 0
62-
63-
# # 基础目录自行改成你想要的
64-
# base_dir = "/home/tailaim/act_grad_dump"
65-
# if not base_dir:
66-
# return
67-
68-
# try:
69-
# idx = _next_act_dump_index(rank, layer_name)
70-
# layer_dir = os.path.join(
71-
# base_dir,
72-
# f"rank_{rank}",
73-
# _sanitize_name(layer_name),
74-
# )
75-
# os.makedirs(layer_dir, exist_ok=True)
76-
# file_path = os.path.join(layer_dir, f"grad_{idx:06d}.pt")
77-
78-
# # 只前几次写盘,避免太多文件
79-
# if idx <= 16:
80-
# torch.save(grad.detach().cpu(), file_path)
81-
82-
# # 只在第一次写时打印一行日志
83-
# if idx == 1:
84-
# try:
85-
# g_shape = tuple(grad.shape)
86-
# g_dtype = str(grad.dtype)
87-
# except Exception:
88-
# g_shape = "unknown"
89-
# g_dtype = "unknown"
90-
# print(
91-
# f"[Rank {rank}] Saved act grad: layer={layer_name}, "
92-
# f"idx={idx:06d}, shape={g_shape}, dtype={g_dtype}, path={file_path}"
93-
# )
94-
# except Exception as e:
95-
# print(f"[Rank {rank}] act grad dump failed for {layer_name}: {e}")
96-
97-
# return hook
98-
9940

10041
class GPTModel(LanguageModule):
10142
"""GPT Transformer language model.

megatron/core/parallel_state.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,6 +559,7 @@ def initialize_model_parallel(
559559
create_gloo_process_groups: bool = True,
560560
high_priority_stream_groups: Optional[List[str]] = None,
561561
sharp_enabled_group: Optional[str] = None,
562+
min_hybrid_context_parallel_size: int = 1,
562563
) -> None:
563564
"""Initialize model data parallel groups.
564565
@@ -973,7 +974,12 @@ def initialize_model_parallel(
973974
if hybrid_context_parallel:
974975
# PyTorch is performing lazy initialization of the communicator group.
975976
# Therefore, we need to perform a nccl call to ensure that the communicator group is created.
976-
group_sizes = [2**i for i in range(int(log2(data_parallel_size)))]
977+
group_sizes = [
978+
2**i
979+
for i in range(
980+
int(log2(min_hybrid_context_parallel_size)), int(log2(data_parallel_size))
981+
)
982+
]
977983
if group_sizes[-1] * 2 == data_parallel_size:
978984
group_sizes.append(data_parallel_size)
979985
for group_size in group_sizes:

0 commit comments

Comments
 (0)