
Commit 1e8285c

Commit message: new
Parent: 3a665e8

File tree: 11 files changed, +533 −26 lines


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -147,3 +147,4 @@ llm_ckpts
 events.*
 memory_trace
 RUN*/
+micro_record/

configs/7B_isp_sft.py

Lines changed: 12 additions & 7 deletions
@@ -3,13 +3,13 @@
 DO_ALERT = False
 
 VOCAB_SIZE = 103168
-SEQ_LEN = 2048
+SEQ_LEN = 16*1024
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 NUM_KV_ATTENTION_HEAD = 8
 MLP_RATIO = 8 / 3
 NUM_LAYER = 32
-BUCKET_SIZE = 512
+BUCKET_SIZE = 256
 
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"

@@ -53,18 +53,19 @@
 TRAIN_FOLDER = '/data/wikipedia/en_test/train_test_dataset'  # "/path/to/dataset"
 VALID_FOLDER = None  # "/path/to/dataset"
 data = dict(
+    data_name="wiki",
     seq_len=SEQ_LEN,
     bucket_size=BUCKET_SIZE,
     # micro_num means the number of micro_batch contained in one gradient update
-    micro_num=4,
+    micro_num=16,
     # packed_length = micro_bsz * SEQ_LEN
     micro_bsz=2,
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
     valid_every=0,
     pack_sample_into_one=False,
-    total_steps=50,
+    total_steps=42,
     skip_batches="",
     # rampup_batch_size (str): A string with three space-separated integers representing the
     # starting batch size, the increment, and the number of steps between

@@ -230,11 +231,15 @@
     interleaved: bool, if `head_first` is `False` and `window_size` > 1, this config could
     interleaved the ranks in the same window to make full use of NIC as much as possible.
 """
+# wdp = world_size // wp // pp  # isp
+# dp = world_size // tp // pp
+# zero1 size is up to wdp
+
 parallel = dict(
     zero1=dict(size=-1),
     tensor=dict(size=2, mode="isp"),
-    pipeline=dict(size=1, interleaved_overlap=True),
-    weight=dict(size=4, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
+    pipeline=dict(size=4, interleaved_overlap=True),
+    weight=dict(size=2, overlap=True, launch_allgather_before="wo", forward_overlap_per="layer"),
     sequence_2D=dict(
         enable=False,
         head_size=2,

@@ -246,7 +251,7 @@
 
 cudnn_deterministic = False
 cudnn_benchmark = False
-
+profile_fwd_bwd = True
 
 # monitor = dict(
 #     # feishu alert configs
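
The three comment lines added above spell out how the group sizes relate under isp. As a worked example, a minimal sketch (not part of the commit; the 64-GPU world size is hypothetical) plugging in this config's tensor=2, pipeline=4, weight=2:

# Hypothetical illustration of the commented formulas above; world_size is an assumption.
world_size = 64             # set by the launcher in practice, not by this config
tp, pp, wp = 2, 4, 2        # tensor, pipeline, and weight sizes from this config

dp = world_size // tp // pp     # data-parallel size: 64 // 2 // 4 = 8
wdp = world_size // wp // pp    # weight-data-parallel size under isp: 64 // 2 // 4 = 8
print(f"dp={dp}, wdp={wdp}")    # zero1=dict(size=-1) lets the ZeRO-1 group span the full wdp group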

configs/7B_llama2.py

Lines changed: 7 additions & 7 deletions
@@ -3,13 +3,13 @@
 DO_ALERT = False
 
 VOCAB_SIZE = 32000
-SEQ_LEN = 2048
-HIDDEN_SIZE = 4096
-NUM_ATTENTION_HEAD = 32
-NUM_KV_ATTENTION_HEAD = 32
-MLP_RATIO = 2.6875
-NUM_LAYER = 32
-
+SEQ_LEN = 16*1024
+HIDDEN_SIZE = 5120
+NUM_ATTENTION_HEAD = 40
+NUM_KV_ATTENTION_HEAD = 40
+MLP_RATIO = 2.7
+NUM_LAYER = 40
+BUCKET_SIZE = 512
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
 # Ckpt folder format:
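
Note that the new shape (hidden 5120, 40 heads and KV heads, 40 layers, MLP ratio 2.7) matches the published Llama 2 13B dimensions rather than the 7B the filename suggests. A rough back-of-the-envelope sketch (an estimate, not from the commit; the untied output head is an assumption) of the implied parameter count:

# Rough parameter-count estimate for the new config values.
vocab, h, n_layer, mlp_ratio = 32000, 5120, 40, 2.7
ffn = int(h * mlp_ratio)                  # 13824, the Llama-2-13B FFN width
attn = 4 * h * h                          # q, k, v, o projections (no GQA here: kv heads == heads)
mlp = 3 * h * ffn                         # gate, up, down projections (SwiGLU assumed)
embed = 2 * vocab * h                     # input embedding + untied output head (assumption)
total = n_layer * (attn + mlp) + embed
print(f"~{total / 1e9:.1f}B parameters")  # ~13B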

internlm/core/context/parallel_context.py

Lines changed: 2 additions & 1 deletion
@@ -170,6 +170,7 @@ def __init__(self):
         self._expert_parallel_group_names = []
         self.is_evaluating = False
         self.v_shape = False
+        self.batch_count = 1
 
     @property
     def config(self):

@@ -516,7 +517,7 @@ def _set_parallel_size_from_config(self, config: dict, key: str, attr_name: str)
 
     def init_parallel_groups(self):
        """Initializes the parallel groups."""
-
+
        # get rank and world size
        rank = self.get_global_rank()
        world_size = self.get_world_size(ParallelMode.GLOBAL)

internlm/core/scheduler/pipeline_scheduler_1f1b.py

Lines changed: 67 additions & 2 deletions
@@ -8,6 +8,9 @@
 
 import torch
 import torch.distributed as dist
+import time
+import os
+import json
 
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc

@@ -202,22 +205,39 @@ def _call_engine(engine, data):  # pylint: disable=W0237
     def load_batch(self, engine, data_iter):
         # Pipeline schedule just puts data in memory,
         batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
-
+        batch_seqlist = []
+        # import pdb
+        # pdb.set_trace()
         # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed,
         # because internlm's current train dataset is packed, even using dummy data.
         # The unpack operation is performed in load_micro_batch().
         if check_data_is_packed(batch_data):
             micro_num = actual_batch_size
         else:
             micro_num = actual_batch_size // gpc.config.data["micro_bsz"]
-
+        # import pdb
+        # breakpoint()
+        for micro_batch_cu in batch_data[0]['cu_seqlens']:
+            micro_batch_seqlist = [int(micro_batch_cu[j]) - int(micro_batch_cu[j - 1]) for j in range(1, len(micro_batch_cu))]
+            batch_seqlist.append(micro_batch_seqlist)
+
         self.microbatch_offset = 0
         self.batch_size = actual_batch_size
         self.batch_data, self.batch_label = batch_data
         self.bsz_stride = self.batch_size // micro_num
         # 'num_microbatches' is no longer an initialization parameter,
         # but is determined on the fly by the Scheduler.
         self.num_microbatches = micro_num  # Rampup or variable bsz size.
+
+        if gpc.config.profile_fwd_bwd and os.environ.get("CUDA_LAUNCH_BLOCKING") == "1" and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            output_dir = os.path.join("./micro_record", gpc.config.data.data_name, f"B{gpc.config.data.bucket_size}_seq{gpc.config.SEQ_LEN}_mb{gpc.config.data.micro_num}", f'S{gpc.batch_count}')
+            os.makedirs(output_dir, exist_ok=True)
+            output_file = os.path.join(output_dir, f"PP_rank_{gpc.get_local_rank(ParallelMode.PIPELINE)}_seq.json")
+
+            with open(output_file, "w") as f:
+                for micro_batch_seqlist in batch_seqlist:
+                    json.dump(micro_batch_seqlist, f)
+                    f.write('\n')
 
     def load_micro_batch(self):
         micro_batch_data, micro_batch_label = self._load_micro_batch(
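
The cu_seqlens recovery in the hunk above turns each micro-batch's cumulative offsets into per-sample sequence lengths. A standalone illustration (hypothetical values, not from the commit):

# cu_seqlens holds cumulative offsets of the samples packed into one micro-batch.
cu_seqlens = [0, 5, 12, 16]
seqlist = [cu_seqlens[j] - cu_seqlens[j - 1] for j in range(1, len(cu_seqlens))]
print(seqlist)  # [5, 7, 4] -> lengths of the three packed samples
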
@@ -592,8 +612,12 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
             input_obj = None
 
         # Run 1F1B in steady state.
+        fwd_times = []
+        bwd_times = []
+
         for i in range(num_1f1b_micropairs):
             # Perform forward computation
+            start_time = time.time()
             output_obj, moe_loss = self._forward_step(
                 engine,
                 input_obj,

@@ -602,6 +626,7 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 accum_loss=accum_loss,
                 accum_moe_loss=accum_moe_loss,
             )
+            fwd_times.append(time.time() - start_time)
 
             if gpc.is_last_rank(ParallelMode.PIPELINE):
                 output_obj_grad = None

@@ -625,7 +650,9 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 output_obj = output_objs.pop(0)
                 moe_loss = moe_losses.pop(0)
 
+            start_bwd_time = time.time()
             input_obj_grad = self._backward_step(engine, i, input_obj, output_obj, output_obj_grad, moe_loss)
+            bwd_times.append(time.time() - start_bwd_time)
 
             if i == (num_1f1b_micropairs - 1):
                 input_obj = None
@@ -644,6 +671,44 @@ def _forward_backward_step(self, engine, return_loss=True, return_output_label=T
                 dtype=self.dtype,
                 scatter_gather_tensors=self.scatter_gather_tensors,
             )
+        if gpc.config.profile_fwd_bwd and os.environ.get("CUDA_LAUNCH_BLOCKING") == "1" and gpc.get_local_rank(ParallelMode.DATA) == 0 and gpc.get_local_rank(ParallelMode.TENSOR) == 0:
+            output_dir = os.path.join("./micro_record", gpc.config.data.data_name, f"B{gpc.config.data.bucket_size}_seq{gpc.config.SEQ_LEN}_mb{gpc.config.data.micro_num}", f'S{gpc.batch_count}')
+            os.makedirs(output_dir, exist_ok=True)
+            output_file = os.path.join(output_dir, f"PP_rank_{gpc.get_local_rank(ParallelMode.PIPELINE)}.json")
+            gpc.batch_count += 1
+
+            history = {
+                "fwd_times": [],
+                "bwd_times": [],
+            }
+
+            # 2. If the file already exists, read back the old data
+            if os.path.exists(output_file):
+                with open(output_file, 'r') as f:
+                    try:
+                        history = json.load(f)
+                    except json.JSONDecodeError:
+                        pass  # skip if the file is empty or corrupted
+
+            # 3. Append the new data
+            history["fwd_times"].extend(fwd_times)
+            history["bwd_times"].extend(bwd_times)
+
+            from collections import OrderedDict
+            data = OrderedDict()
+            # 4. Update the running averages
+            data["avg_fwd"] = sum(history["fwd_times"]) / len(history["fwd_times"])
+            data["avg_bwd"] = sum(history["bwd_times"]) / len(history["bwd_times"])
+            f_f = round(data["avg_fwd"] / data["avg_fwd"], 3)
+            b_f = round(data["avg_bwd"] / data["avg_fwd"], 3)
+            data["f_b_w"] = (f_f, b_f)
+            data["fwd_times"] = history["fwd_times"]
+            data["bwd_times"] = history["bwd_times"]
+
+            # 5. Write the merged record back to the file
+            with open(output_file, 'w') as f:
+                json.dump(data, f, indent=4)
+
 
         # Run cooldown backward passes.
         for i in range(num_warmup_microsteps):
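
With profile_fwd_bwd = True and CUDA_LAUNCH_BLOCKING=1, each S<step> directory ends up holding one JSON record per pipeline rank. A minimal reader sketch (not part of the commit; the glob pattern assumes the 7B_isp_sft.py values above, i.e. data_name="wiki", BUCKET_SIZE=256, SEQ_LEN=16384, micro_num=16):

import glob
import json

# Layout: ./micro_record/<data_name>/B<bucket>_seq<seq_len>_mb<micro_num>/S<step>/PP_rank_<r>.json
for path in sorted(glob.glob("./micro_record/wiki/B256_seq16384_mb16/S*/PP_rank_*.json")):
    with open(path) as f:
        record = json.load(f)
    # f_b_w stores (avg_fwd/avg_fwd, avg_bwd/avg_fwd): the fwd:bwd ratio with fwd normalized to 1.0
    print(path, record["avg_fwd"], record["avg_bwd"], record["f_b_w"])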

internlm/data/build_dataloader.py

Lines changed: 17 additions & 3 deletions
@@ -23,6 +23,7 @@
 from internlm.data.tokenized.batch_sampler import (
     StaticBatchSampler,
     get_dpsampler_dataloader,
+    BucketGroupBatchSampler,
 )
 from internlm.data.tokenized.collaters import (
     generation_collate_fn,

@@ -86,8 +87,9 @@ def get_tokenized_train_loader_items(data_cfg):
         pack_sample_into_one=data_cfg.get("pack_sample_into_one", False),
         bucket_size=data_cfg.get("bucket_size", 0)
     )
-
-    train_sampler = StaticBatchSampler(
+    if data_cfg.get("bucket_size", 0) > 0:
+        enable_bucket_balance = True
+        train_sampler = BucketGroupBatchSampler(
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
         rampup_batch_size=data_cfg.rampup_batch_size,

@@ -96,7 +98,19 @@ def get_tokenized_train_loader_items(data_cfg):
         drop_last=True,
         data_rank=gpc.get_local_rank(ParallelMode.DATA),
         data_world_size=gpc.get_world_size(ParallelMode.DATA),
-    )
+            enable_bucket_balance=enable_bucket_balance,
+        )
+    else:
+        train_sampler = StaticBatchSampler(
+            train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
+            batch_size=data_cfg.micro_num,
+            rampup_batch_size=data_cfg.rampup_batch_size,
+            micro_bsz=data_cfg.micro_bsz,
+            seed=data_cfg.get("seed", 1024),
+            drop_last=True,
+            data_rank=gpc.get_local_rank(ParallelMode.DATA),
+            data_world_size=gpc.get_world_size(ParallelMode.DATA),
+        )
     train_collate_fn = partial(packed_collate_fn, packed_length=data_cfg.packed_length)
 
     return train_ds, train_sampler, train_collate_fn
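
The commit routes batching through BucketGroupBatchSampler whenever bucket_size > 0. Its implementation is not shown here, but the idea of length-bucketed batching can be sketched conceptually (an assumption about intent, not the real sampler):

from collections import defaultdict

def bucket_batches(sample_lengths, bucket_size, batch_size):
    # Group sample indices by length bucket so each batch draws from one bucket,
    # keeping per-micro-batch work roughly balanced across pipeline ranks.
    buckets = defaultdict(list)
    for idx, length in enumerate(sample_lengths):
        buckets[length // bucket_size].append(idx)
    for indices in buckets.values():
        for i in range(0, len(indices) - batch_size + 1, batch_size):
            yield indices[i:i + batch_size]

# Example: with bucket_size=256, lengths 100 and 200 share bucket 0; 800 and 900 share bucket 3.
print(list(bucket_batches([100, 200, 300, 700, 800, 900], bucket_size=256, batch_size=2)))
# -> [[0, 1], [4, 5]]  (buckets with fewer than batch_size samples yield nothing here)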
