
Commit aaafb38

[Device]Support npu (#6159)
* support npu
* support pretrain
  support pretrain
  fix
* support lora
  fix
  fix
* support chatglm
  fix
  fxi
  fix
  [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
  fix
  fix
  [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
  fix
  [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
  fix
  fix
  fix
* Update train.py
* Update train.py
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix
* fix
* fix
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent e994c64 commit aaafb38

File tree

18 files changed: +292 -149 lines


applications/Colossal-LLaMA/colossal_llama/dataset/conversation.py

Lines changed: 1 addition & 1 deletion
@@ -100,7 +100,7 @@ def dict(self):
     messages=[],
     offset=0,
     sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
-    seps=["<|begin_of_text|>", "<|end_of_text|>"],
+    seps=["<|begin_of_text|>", "<|eot_id|>"],
 )

 default_conversation = LLaMA3_Conv

applications/Colossal-LLaMA/colossal_llama/dataset/spliced_and_tokenized_dataset.py

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ def supervised_tokenize_sft(

     assert (
         tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
-    ), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
+    ), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

     if ignore_index is None:
         ignore_index = IGNORE_INDEX
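The sharper assertion message pairs with the separator change above: data preparation now fails loudly if the tokenizer's special tokens drift from the conversation template. A minimal consistency check along the same lines, assuming the applications/Colossal-LLaMA package is importable; the checkpoint path is a placeholder, not taken from the PR:

    from transformers import AutoTokenizer

    from colossal_llama.dataset.conversation import default_conversation

    # Placeholder path; point this at the LLaMA-3 checkpoint being fine-tuned.
    tokenizer = AutoTokenizer.from_pretrained("path/to/llama3-checkpoint", trust_remote_code=True)

    # Mirrors the assertion in supervised_tokenize_sft: bos/eos must match the template seps.
    assert tokenizer.bos_token == default_conversation.seps[0], (tokenizer.bos_token, default_conversation.seps[0])
    assert tokenizer.eos_token == default_conversation.seps[1], (tokenizer.eos_token, default_conversation.seps[1])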

applications/Colossal-LLaMA/colossal_llama/utils/ckpt_io.py

Lines changed: 5 additions & 1 deletion
@@ -43,6 +43,7 @@ def save_checkpoint(
     step: int,
     batch_size: int,
     coordinator: DistCoordinator,
+    use_lora: bool = False,
 ) -> None:
     """
     Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
@@ -51,7 +52,10 @@ def save_checkpoint(
     save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
     os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

-    booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
+    if use_lora:
+        booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
+    else:
+        booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

     booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
     booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
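For reference, a usage sketch of the extended signature, matching how train.py calls it later in this commit; `args`, `booster`, `model`, `optimizer`, `lr_scheduler`, and `coordinator` are assumed to exist from earlier setup:

    from colossal_llama.utils.ckpt_io import save_checkpoint

    # When --lora_rank > 0, only the LoRA adapter weights are written to "modeling";
    # otherwise the full sharded model is saved as before.
    save_checkpoint(
        save_dir=args.save_dir,
        booster=booster,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        epoch=epoch,
        step=step + 1,
        batch_size=args.batch_size,
        coordinator=coordinator,
        use_lora=(args.lora_rank > 0),
    )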

applications/Colossal-LLaMA/train.py

Lines changed: 61 additions & 54 deletions
@@ -21,6 +21,7 @@
 from colossal_llama.utils.froze import freeze_non_embeds_parameters
 from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
 from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
+from peft import LoraConfig
 from torch.utils.tensorboard import SummaryWriter
 from tqdm import tqdm
 from transformers import AutoModelForCausalLM, AutoTokenizer
@@ -65,7 +66,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "gemini_auto":
@@ -75,7 +76,7 @@ def train(args) -> None:
             initial_scale=2**16,
             max_norm=args.grad_clip,
             enable_gradient_accumulation=(args.accumulation_steps > 1),
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_flash_attention=args.use_flash_attn,
         )
     elif args.plugin == "zero2":
@@ -101,10 +102,9 @@ def train(args) -> None:
             sequence_parallelism_mode=args.sp_mode,
             zero_stage=args.zero_stage,
             enable_flash_attention=args.use_flash_attn,
-            enable_fused_normalization=torch.cuda.is_available(),
+            enable_fused_normalization=get_accelerator().is_available(),
             enable_sequence_parallelism=args.enable_sequence_parallelism,
             cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
-            parallel_output=False,
             max_norm=args.grad_clip,
             precision=args.mixed_precision,
             microbatch_size=args.microbatch_size,
@@ -117,11 +117,17 @@ def train(args) -> None:
     # ======================================================
     # Initialize Tokenizer, Dataset, Collator and Dataloader
     # ======================================================
-    tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
+    tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
     if args.pad_token == "eos":
-        tokenizer.pad_token = tokenizer.eos_token
+        try:
+            tokenizer.pad_token = tokenizer.eos_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     elif args.pad_token == "unk":
-        tokenizer.pad_token = tokenizer.unk_token
+        try:
+            tokenizer.pad_token = tokenizer.unk_token
+        except AttributeError:
+            coordinator.print_on_master(f"pad_token can't be set")
     tokenizer.add_bos_token = False
     tokenizer.add_eos_token = False

@@ -164,33 +170,31 @@ def train(args) -> None:
     # ======================================================
     # Initialize Model, Objective, Optimizer and LR Scheduler
     # ======================================================
+    # When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
     init_ctx = (
         LazyInitContext(default_device=get_current_device())
-        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
+        if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
         else nullcontext()
     )
     with init_ctx:
-        if args.use_flash_attn:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                attn_implementation="flash_attention_2",
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
-        else:
-            model = AutoModelForCausalLM.from_pretrained(
-                args.pretrained,
-                torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
-                trust_remote_code=True,
-            )
+        model = AutoModelForCausalLM.from_pretrained(
+            args.pretrained,
+            torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
+            trust_remote_code=True,
+        )
         # Freeze part of parameters.
         if args.freeze_non_embeds_params:
            freeze_non_embeds_parameters(model=model)
+
+    if args.lora_rank > 0:
+        lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
+        model = booster.enable_lora(model, lora_config=lora_config)
+
     # this is essential, otherwise the grad checkpoint will not work.
     model.train()

     if args.use_grad_checkpoint:
-        model.gradient_checkpointing_enable()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
         coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

     model_numel = get_model_numel(model)
@@ -327,6 +331,7 @@ def train(args) -> None:
                         step=step + 1,
                         batch_size=args.batch_size,
                         coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
                     )
                     coordinator.print_on_master(
                         f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
@@ -371,44 +376,45 @@ def train(args) -> None:
                 total_loss.fill_(0.0)
                 pbar.update()

-            # Save modeling.
-            save_model_condition = (
-                args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
-            )
+                # Save modeling.
+                save_model_condition = (
+                    args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
+                )

-            if not args.skip_save_each_epoch:
-                save_model_condition = save_model_condition or (step + 1) == len(dataloader)
+                if not args.skip_save_each_epoch:
+                    save_model_condition = save_model_condition or (step + 1) == len(dataloader)

-            if save_model_condition and not args.benchmark:
-                coordinator.print_on_master("\nStart saving model checkpoint with running states")
+                if save_model_condition and not args.benchmark:
+                    coordinator.print_on_master("\nStart saving model checkpoint with running states")

-                if args.use_neft:
-                    coordinator.print_on_master("Deactivate NEFTune before saving model.")
-                    deactivate_neftune(model, handle)
+                    if args.use_neft:
+                        coordinator.print_on_master("Deactivate NEFTune before saving model.")
+                        deactivate_neftune(model, handle)

-                accelerator.empty_cache()
-                save_checkpoint(
-                    save_dir=args.save_dir,
-                    booster=booster,
-                    model=model,
-                    optimizer=optimizer,
-                    lr_scheduler=lr_scheduler,
-                    epoch=epoch,
-                    step=step + 1,
-                    batch_size=args.batch_size,
-                    coordinator=coordinator,
-                )
-                coordinator.print_on_master(
-                    f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
-                )
+                    accelerator.empty_cache()
+                    save_checkpoint(
+                        save_dir=args.save_dir,
+                        booster=booster,
+                        model=model,
+                        optimizer=optimizer,
+                        lr_scheduler=lr_scheduler,
+                        epoch=epoch,
+                        step=step + 1,
+                        batch_size=args.batch_size,
+                        coordinator=coordinator,
+                        use_lora=(args.lora_rank > 0),
+                    )
+                    coordinator.print_on_master(
+                        f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
+                    )

-            if args.use_neft:
-                coordinator.print_on_master("Activate NEFTune.")
-                model, handle = activate_neftune(model)
+                if args.use_neft:
+                    coordinator.print_on_master("Activate NEFTune.")
+                    model, handle = activate_neftune(model)

-            # Delete cache.
-            # del batch, batch_labels, batch_output, loss
-            accelerator.empty_cache()
+                # Delete cache.
+                # del batch, batch_labels, batch_output, loss
+                accelerator.empty_cache()

         # the continue epochs are not resumed, so we need to reset the sampler start index and start step
         dataloader.sampler.set_start_index(start_index=0)
@@ -522,6 +528,7 @@ def train(args) -> None:
     parser.add_argument(
         "--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
     )
+    parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

     # Additional arguments for benchmark.
     parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
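Taken together, the train.py changes add an opt-in LoRA path: a positive `--lora_rank` builds a peft `LoraConfig`, wraps the model via `booster.enable_lora`, skips lazy initialization, and later saves adapters instead of full shards. A condensed sketch of that path, with the plugin choice and checkpoint path as illustrative placeholders rather than values from the PR:

    import torch
    from colossalai.booster import Booster
    from colossalai.booster.plugin import LowLevelZeroPlugin
    from peft import LoraConfig
    from transformers import AutoModelForCausalLM

    # Illustrative setup; the real script picks the plugin from --plugin and assumes
    # colossalai.launch_from_torch() has already initialized the distributed state.
    booster = Booster(plugin=LowLevelZeroPlugin(stage=2, precision="bf16"))

    model = AutoModelForCausalLM.from_pretrained(
        "path/to/pretrained",  # placeholder checkpoint path
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    )

    lora_rank = 8  # corresponds to --lora_rank; 0 keeps full-parameter training
    if lora_rank > 0:
        lora_config = LoraConfig(task_type="CAUSAL_LM", r=lora_rank, lora_alpha=32, lora_dropout=0.1)
        model = booster.enable_lora(model, lora_config=lora_config)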

colossalai/lazy/lazy_init.py

Lines changed: 3 additions & 3 deletions
@@ -509,9 +509,9 @@ def wrap_factory_like_method(orig_target, target):
            # factory_like functions (eg. torch.empty_like())
            def wrapper(*args, **kwargs):
                orig_t = args[0]
-                return self.tensor_cls(
-                    orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
-                )
+                device = kwargs.pop("device", orig_t.device)
+                dtype = kwargs.pop("dtype", orig_t.dtype)
+                return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

            return wrapper, target

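The lazy-init fix lets callers of factory-like ops pass an explicit `device=` or `dtype=` without colliding with the defaults copied from the source tensor (previously both would be forwarded, producing a duplicate-keyword error). A rough sketch of the behaviour this enables, assuming `LazyInitContext` is used standalone:

    import torch
    from colossalai.lazy import LazyInitContext

    with LazyInitContext():
        base = torch.empty(4, 4)
        # With the patch, the explicit dtype overrides base.dtype instead of being
        # forwarded alongside it; the same applies to an explicit device= keyword.
        half = torch.empty_like(base, dtype=torch.float16)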
colossalai/legacy/communication/p2p.py

Lines changed: 1 addition & 1 deletion
@@ -171,7 +171,7 @@ def _communicate(
         for req in reqs:
             req.wait()
     # To protect against race condition when using batch_isend_irecv().
-    torch.cuda.synchronize()
+    get_accelerator().synchronize()

     if recv_prev and recv_prev_split:
         if isinstance(tensor_recv_prev, torch.Tensor):

colossalai/pipeline/p2p.py

Lines changed: 20 additions & 13 deletions
@@ -14,6 +14,8 @@
 from torch.distributed import distributed_c10d as c10d
 from torch.utils._pytree import tree_flatten, tree_unflatten

+from colossalai.accelerator import get_accelerator
+
 from .stage_manager import PipelineStageManager


@@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
     buf = tensor.numpy().tobytes()[:tensor_size]
     if b"cuda" in buf:
         buf_array = bytearray(buf)
-        device_index = torch.cuda.current_device()
+        device_index = get_accelerator().current_device()
         # There might be more than one output tensors during forward
         for cuda_str in re.finditer(b"cuda", buf_array):
             pos = cuda_str.start()
@@ -86,7 +88,7 @@ def _broadcast_object_list(
     else:
         current_device = torch.device("cpu")
         if is_nccl_backend:
-            current_device = torch.device("cuda", torch.cuda.current_device())
+            current_device = torch.device("cuda", get_accelerator().current_device())

     my_rank = dist.get_rank()
     # Serialize object_list elements to tensors on src rank.
@@ -139,29 +141,29 @@ def _broadcast_object_list(
                 # unconsistence in device
                 if (
                     isinstance(unpickle_object, torch.Tensor)
-                    and unpickle_object.device.index != torch.cuda.current_device()
+                    and unpickle_object.device.index != get_accelerator().current_device()
                 ):
-                    unpickle_object = unpickle_object.cuda()
+                    unpickle_object = unpickle_object.to(get_accelerator().current_device())

                 object_list[i] = unpickle_object


-def _check_for_nccl_backend(group):
+def _check_for_nccl_hccl_backend(group):
     pg = group or c10d._get_default_group()
     # Gate PG wrapper check on Gloo availability.
     if c10d._GLOO_AVAILABLE:
         # It is not expected for PG to be wrapped many times, but support it just in case
         while isinstance(pg, c10d._ProcessGroupWrapper):
             pg = pg.wrapped_pg

-    return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
+    return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL


 def _check_device(group):
-    is_nccl_backend = _check_for_nccl_backend(group)
+    is_nccl_backend = _check_for_nccl_hccl_backend(group)
     current_device = torch.device("cpu")
     if is_nccl_backend:
-        current_device = torch.device("cuda", torch.cuda.current_device())
+        current_device = torch.device(get_accelerator().current_device())
     return current_device, is_nccl_backend


@@ -348,8 +350,11 @@ def _send_recv_serialization_object(

     unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

-    if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
-        unpickle_object = unpickle_object.cuda()
+    if (
+        isinstance(unpickle_object, torch.Tensor)
+        and unpickle_object.device.index != get_accelerator().current_device()
+    ):
+        unpickle_object = unpickle_object.to(get_accelerator().current_device())

     return unpickle_object

@@ -474,9 +479,11 @@ def _p2p_comm(
     recv_prev_shape = None

     if tensor_send_next is not None:
-        send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
+        send_next_shape = torch.tensor(
+            tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
+        )
     if recv_prev:
-        recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
+        recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

     ops = []
     if send_next_shape is not None:
@@ -501,7 +508,7 @@ def _p2p_comm(
     # send and recv data
     tensor_recv_prev = None
     if recv_prev:
-        tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
+        tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

     ops = []
     if tensor_send_next is not None:
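The recurring substitution in this file (and in the legacy communication module above) routes device queries through ColossalAI's accelerator abstraction, so the same p2p code path can run on CUDA GPUs or Ascend NPUs. A minimal sketch of the pattern, assuming `get_accelerator()` resolves to the backend detected at runtime:

    import torch
    from colossalai.accelerator import get_accelerator

    accelerator = get_accelerator()            # CUDA backend on GPU hosts, NPU backend on Ascend
    device_index = accelerator.current_device()

    # Allocate a communication buffer on the active device without hard-coding "cuda".
    shape_buf = torch.empty((3,), device=device_index, dtype=torch.int64)

    # Backend-agnostic replacement for torch.cuda.synchronize().
    accelerator.synchronize()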

colossalai/pipeline/schedule/interleaved_pp.py

Lines changed: 1 addition & 2 deletions
@@ -2,7 +2,6 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

 import torch
-import torch.cuda
 import torch.distributed
 from torch.nn import Module, ModuleList
 from torch.utils._pytree import tree_map
@@ -18,7 +17,7 @@
 from .base import PipelineSchedule


-def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
+def _wait_p2p(wait_handles) -> None:
     if wait_handles is not None:
         for req in wait_handles:
             req.wait()

colossalai/pipeline/schedule/one_f_one_b.py

Lines changed: 0 additions & 1 deletion
@@ -2,7 +2,6 @@
 from typing import Any, Callable, Dict, Iterable, List, Optional, Union

 import torch
-import torch.cuda
 from torch.nn import Module
 from torch.utils._pytree import tree_map
