Commit 24c7a0e

add
1 parent df172cd commit 24c7a0e

3 files changed: +39 −15 lines changed


scripts/train.sh

Lines changed: 22 additions & 12 deletions
@@ -2,15 +2,19 @@
 
 set -x
 
+# ARNOLD_WORKER_GPU=8
+# ARNOLD_WORKER_NUM=1
+# ARNOLD_ID=0
+
 # set dist args
-# SINGLE=1
+SINGLE=1
 nproc_per_node=${ARNOLD_WORKER_GPU}
 
 if [ ! -z "$SINGLE" ] && [ "$SINGLE" != "0" ]; then
     echo "[single node alone] SINGLE=$SINGLE"
     nnodes=1
     node_rank=0
-    nproc_per_node=1
+    nproc_per_node=8
     master_addr=127.0.0.1
     master_port=12345
 else
@@ -33,9 +37,13 @@ echo "[master_port: ${master_port}]"
 
 # set up envs
 export OMP_NUM_THREADS=8
-export NCCL_IB_DISABLE=0
+export NCCL_IB_DISABLE=1
 export NCCL_IB_GID_INDEX=3
-export NCCL_SOCKET_IFNAME=eth0
+# export NCCL_SOCKET_IFNAME=xgbe0
+
+# export NCCL_DEBUG=info
+# export NCCL_IB_DISABLE=1
+# export NCCL_P2P_DISABLE=1
 
 
 BED=checkpoints
@@ -60,12 +68,14 @@ local_out_path=$LOCAL_OUT/${exp_name}
 rm -rf ${bed_path}
 rm -rf ${local_out_path}
 
-torchrun \
---nproc_per_node=${nproc_per_node} \
---nnodes=${nnodes} \
---node_rank=${node_rank} \
---master_addr=${master_addr} \
---master_port=${master_port} \
+# torchrun \
+--nproc_per_node=${nproc_per_node} \
+--nnodes=${nnodes} \
+--node_rank=${node_rank} \
+--master_addr=${master_addr} \
+--master_port=${master_port} \
+python -m torch.distributed.launch --nproc-per-node=8 \
+--local_ranks_filter 0 \
 train.py \
 --ep=100 \
 --opt=adamw \
@@ -96,14 +106,14 @@ train.py \
 --use_streaming_dataset 1 \
 --iterable_data_buffersize 30000 \
 --Ct5=2048 \
---t5_path=weights/flan-t5-xl \
+--t5_path=google/flan-t5-xl \
 --vae_type 32 \
 --vae_ckpt=weights/infinity_vae_d32_rdn_short.pth \
 --wp 0.00000001 \
 --wpe=1 \
 --dynamic_resolution_across_gpus 1 \
 --enable_dynamic_length_prompt 1 \
---reweight_loss_by_scale 1 \
+--reweight_loss_by_scale 0 \
 --add_lvl_embeding_only_first_block 1 \
 --rope2d_each_sa_layer 1 \
 --rope2d_normalized_by_hw 2 \
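
Two things stand out in the launcher change above: the torchrun invocation is commented out in favour of python -m torch.distributed.launch --nproc-per-node=8, and the old --nproc_per_node=… through --master_port=… continuation lines are left uncommented beneath the commented # torchrun \, so as committed they would be parsed as a stray command; presumably they were meant to be commented out as well. Either launcher spawns one process per GPU and hands each worker its rank through environment variables. The following is an illustrative sketch (not code from this repo; the helper name is made up) of how a spawned worker typically consumes those variables:

# Illustrative sketch, not part of train.py: how a worker started by torchrun or
# (recent) torch.distributed.launch reads the env vars the launcher exports.
import os
import torch
import torch.distributed as dist

def init_distributed():
    rank = int(os.environ.get("RANK", 0))               # global rank of this process
    local_rank = int(os.environ.get("LOCAL_RANK", 0))   # GPU index on this node
    world_size = int(os.environ.get("WORLD_SIZE", 1))   # nnodes * nproc_per_node
    torch.cuda.set_device(local_rank)
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    return rank, local_rank, world_size

On current PyTorch, torch.distributed.launch is a deprecated wrapper around torchrun, which is presumably why a torchrun-style flag such as --local_ranks_filter (restrict launcher log output to the listed local ranks) is passed here; treat that reading of the flag as an assumption rather than documentation.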

train.py

Lines changed: 9 additions & 2 deletions
@@ -327,6 +327,7 @@ def main_train(args: arg_util.Args):
     # build wandb logger
     if dist.is_master():
         wandb_utils.wandb.init(project=args.project_name, name=args.exp_name, config={})
+
     for ep in range(start_ep, args.ep):
         if ep % ep_lg == 0 or ep == start_ep:
             print(f'[PT info] from ep{start_ep} it{start_it}, acc_str: {acc_str}, diffs: {args.diffs}, =======> bed: {args.bed} <=======\n')
@@ -483,10 +484,15 @@ def train_one_ep(
         with maybe_record_function('before_train'):
             # [get data]
             inp, captions = data
-            tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length, padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
+            tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length,
+                                    padding='max_length', truncation=True, return_tensors='pt') # todo: put this into dataset
+            print("gongwb tokens:", tokens)
+
             input_ids = tokens.input_ids.cuda(non_blocking=True)
             mask = tokens.attention_mask.cuda(non_blocking=True)
+
             text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()
+            print("gongwb text_features:", text_features)
 
             lens: List[int] = mask.sum(dim=-1).tolist()
             cu_seqlens_k = F.pad(mask.sum(dim=-1).to(dtype=torch.int32).cumsum_(0), (1, 0))
@@ -521,7 +527,8 @@ def train_one_ep(
             step_cnt += int(stepping)
 
             with maybe_record_function('in_training'):
-                grad_norm_t, scale_log2_t = trainer.train_step(
+                #grad_norm_t, scale_log2_t =
+                trainer.train_step(
                     ep=ep, it=it, g_it=g_it, stepping=stepping, clip_decay_ratio=clip_decay_ratio,
                     metric_lg=me,
                     logging_params=stepping and step_cnt == 1 and (ep < 4 or ep in logging_params_milestone),
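
Besides the debug prints, the only functional changes touching this code path are the re-wrapped text_tokenizer call and, in the run script, pointing --t5_path at the Hugging Face Hub id google/flan-t5-xl instead of the local weights/flan-t5-xl directory. For reference, here is a self-contained sketch of the same tokenize-then-encode step, assuming the text encoder is a Flan-T5 encoder loaded via the transformers library (illustrative only; this is not the repo's own model-loading code):

# Sketch under the assumption that --t5_path=google/flan-t5-xl resolves to a T5 encoder.
import torch
from transformers import AutoTokenizer, T5EncoderModel

text_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-xl")
text_encoder = T5EncoderModel.from_pretrained("google/flan-t5-xl").cuda().eval()

captions = ["a photo of a corgi", "an oil painting of a lighthouse at dusk"]
tokens = text_tokenizer(text=captions, max_length=text_tokenizer.model_max_length,
                        padding='max_length', truncation=True, return_tensors='pt')
input_ids = tokens.input_ids.cuda(non_blocking=True)
mask = tokens.attention_mask.cuda(non_blocking=True)

with torch.no_grad():
    # [batch, max_length, hidden]; padded positions are 0 in the attention mask
    text_features = text_encoder(input_ids=input_ids, attention_mask=mask)['last_hidden_state'].float()

lens = mask.sum(dim=-1).tolist()  # true caption lengths, as computed in train_one_ep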

trainer.py

Lines changed: 8 additions & 1 deletion
@@ -159,6 +159,9 @@ def train_step(
         V = self.vae_local.vocab_size
         device = inp_B3HW.device
 
+        print(f"gongwb B: {B}, T: {T}, V:{V}")
+        print("gongwb inp_B3HW:", inp_B3HW)
+
         h_div_w = inp_B3HW.shape[-2] / inp_B3HW.shape[-1]
         h_div_w_templates = np.array(list(dynamic_resolution_h_w.keys()))
         h_div_w_template = h_div_w_templates[np.argmin(np.abs(h_div_w-h_div_w_templates))]
@@ -184,7 +187,10 @@ def train_step(
         x_BLC_wo_prefix = x_BLC_wo_prefix[:, :(training_seq_len-np.array(scale_schedule[0]).prod()), :]
 
         self.gpt_wo_ddp.forward
-        logits_BLV = self.gpt(text_cond_tuple, x_BLC_wo_prefix, scale_schedule=scale_schedule[:training_scales]) # [bs, 1*1+...+64*64, vocab_size or log2(vocab_size)*2]
+        logits_BLV = self.gpt(text_cond_tuple,
+                              x_BLC_wo_prefix, scale_schedule=scale_schedule[:training_scales]) # [bs, 1*1+...+64*64, vocab_size or log2(vocab_size)*2]
+        print("gongwb self.gpt:", self.gpt)
+        print(f"gongwb logits_BLV:{logits_BLV.shape}")
         self.batch_size, self.seq_len = logits_BLV.shape[:2]
 
         self.seq_len_each = [idx_Bl.shape[1] for idx_Bl in gt_ms_idx_Bl]
@@ -214,6 +220,7 @@ def train_step(
         lw = 1. / self.seq_len
         loss = loss.mul(lw).sum(dim=-1).mean()
 
+        return
         # [backward]
         grad_norm_t, scale_log2_t = self.gpt_opt.backward_clip_step(ep=ep, it=it, g_it=g_it, stepping=stepping, logging_params=logging_params, loss=loss, clip_decay_ratio=clip_decay_ratio, stable=args.stable)
 
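
In trainer.py the commit adds shape/debug prints, re-wraps the self.gpt call, and inserts an early return immediately before the backward/clip step, so train_step now stops after the loss is computed; this matches train_one_ep above no longer unpacking grad_norm_t, scale_log2_t. The reduction it stops at weights each position by lw = 1. / self.seq_len before summing, which for a fixed sequence length is a per-sample mean over positions followed by a batch mean. A toy illustration (made-up tensors, not repo code):

# Toy check of the loss reduction used in train_step, with invented shapes.
import torch

batch, seq_len = 2, 5
loss = torch.rand(batch, seq_len)          # stand-in for per-token loss with reduction='none'
lw = 1. / seq_len
reduced = loss.mul(lw).sum(dim=-1).mean()  # weighted sum over positions, then mean over batch
assert torch.allclose(reduced, loss.mean())  # equals a plain global mean at fixed seq_len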
