
Commit 3f4818c

committed
[feat] support hybrid test
1 parent d3fd485 commit 3f4818c

File tree

1 file changed: +172 -0 lines changed

@@ -0,0 +1,172 @@
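# Smoke test for ColossalAI's HybridParallelPlugin on a Qwen2.5-3B model: launch a
# distributed run, boost model/optimizer/dataloader, then train either through the
# pipeline schedule (pp_size > 1) or a plain data-parallel loop. The dataset and
# checkpoint paths below point to local files on the author's machine.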
import torch
import torch.distributed as dist
from coati.dataset.loader import RawConversationDataset
from torch.utils.data import Dataset
from tqdm import tqdm
from transformers import AutoTokenizer, Qwen2ForCausalLM

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin, Plugin
from colossalai.cluster import DistCoordinator
from colossalai.nn.optimizer import HybridAdam

# Define training parameters
BATCH_SIZE = 8
NUM_EPOCHS = 3
LEARNING_RATE = 2e-5
GRADIENT_ACCUMULATION_STEPS = 1
DATA_PATH = "/home/duanjunwen/datasets/math_dataset.jsonl"
# NOTE: torch.npu is only present when the torch_npu (Ascend) extension is installed.
Device = torch.device("npu" if torch.npu.is_available() else "cpu")
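

# RandomDataset produces random token ids and attention masks; it is kept as an
# optional stand-in for the real conversation dataset (the corresponding call in
# test_hybrid_qwen below is commented out).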
class RandomDataset(Dataset):
    def __init__(self, num_samples, sequence_length, vocab_size=10000):
        self.num_samples = num_samples
        self.sequence_length = sequence_length
        self.vocab_size = vocab_size
        self.input_idx = torch.randint(0, vocab_size, (num_samples, sequence_length))
        self.attention_mask = torch.randint(0, 2, (num_samples, sequence_length), dtype=torch.long)

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return {"input_ids": self.input_idx[idx], "attention_mask": self.attention_mask[idx]}
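

# Qwen2.5-3B is loaded from a local checkpoint. FlashAttention-2 is generally not
# available on Ascend NPUs, so the eager attention implementation is selected there.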
def load_model_and_tokenizer():
    attn_impl = "eager" if get_accelerator().name == "npu" else "flash_attention_2"
    tokenizer = AutoTokenizer.from_pretrained(
        "/home/duanjunwen/models/Qwen/Qwen2.5-3B",
        trust_remote_code=True,
    )
    model = Qwen2ForCausalLM.from_pretrained(
        "/home/duanjunwen/models/Qwen/Qwen2.5-3B",
        trust_remote_code=True,
        attn_implementation=attn_impl,
    )
    return tokenizer, model
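

# Average a scalar loss across the data-parallel group (or the whole world if the
# plugin does not expose a dp_group), so every rank reports the same value.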
def all_reduce_mean(loss: torch.Tensor, plugin: Plugin) -> torch.Tensor:
    loss = loss.data
    group = getattr(plugin, "dp_group", None)
    dist.all_reduce(loss, group=group)
    return loss / dist.get_world_size(group)


def test_hybrid_qwen():
    colossalai.launch_from_torch()
    get_accelerator()
    coordinator = DistCoordinator()
    tokenizer, model = load_model_and_tokenizer()
    # dataset = RandomDataset(num_samples=100, sequence_length=2304)
    dataset = RawConversationDataset(tokenizer, DATA_PATH, 1024)
    # dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

    optimizer = HybridAdam(model.parameters(), lr=LEARNING_RATE)
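    # 2-way tensor parallelism, no pipeline parallelism, bf16 mixed precision and ZeRO
    # stage 2; the commented-out line below is an alternative config that additionally
    # enables 2-stage pipeline parallelism with 4 microbatches and flash attention.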
    plugin = HybridParallelPlugin(tp_size=2, pp_size=1, precision="bf16", zero_stage=2)
    # plugin = HybridParallelPlugin(tp_size=2, pp_size=2, precision="bf16", zero_stage=1, num_microbatches=4, enable_flash_attention=True)

    dataloader = plugin.prepare_dataloader(
        dataset=dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True,
    )

    booster = Booster(plugin=plugin)

    model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, None, dataloader)

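    # With pipeline parallelism the loss is only produced on the last stage, so the
    # last rank drives logging; otherwise the usual global master rank does.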
    def is_master():
        if isinstance(plugin, HybridParallelPlugin) and plugin.pp_size > 1:
            return coordinator.rank == coordinator.world_size - 1
        return coordinator.is_master()

    #####
    # train
    #####
    model.train()

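    # Pipeline-parallel runs step through booster.execute_pipeline(); otherwise a
    # plain forward/backward loop with gradient accumulation is used.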
    for epoch in range(NUM_EPOCHS):
        if booster.plugin.pp_size > 1:
            data_iter = iter(dataloader)
            step_bar = tqdm(
                range(len(dataloader)),
                desc="Step",
                disable=not is_master(),
            )
            for step in step_bar:
                print(f"data_iter {data_iter}")
                outputs = booster.execute_pipeline(
                    data_iter,
                    model,
                    criterion=lambda outputs, inputs: outputs[0],
                    optimizer=optimizer,
                    return_loss=True,
                )
                loss = outputs["loss"]
                if booster.plugin.stage_manager.is_last_stage():
                    global_loss = all_reduce_mean(loss, plugin)

                optimizer.step()

                if booster.plugin.stage_manager.is_last_stage():
                    grad_norm = optimizer.get_grad_norm()
                    step_bar.set_postfix({"loss": global_loss.item(), "grad_norm": grad_norm})

                optimizer.zero_grad()
        else:
            total_loss = 0
            for step, batch in enumerate(dataloader):
                input_ids = batch["input_ids"].to(device=model.module.device)
                attention_mask = batch["attention_mask"].to(device=model.module.device)
                outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=input_ids)
                loss = outputs.loss
                print(f"loss {loss} outputs {outputs}")
                loss = loss / GRADIENT_ACCUMULATION_STEPS
                booster.backward(loss, optimizer)

                if (step + 1) % GRADIENT_ACCUMULATION_STEPS == 0:
                    optimizer.step()
                    optimizer.zero_grad()

                total_loss += loss.item()

            print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}")


if __name__ == "__main__":
    test_hybrid_qwen()
