Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def dict(self):
messages=[],
offset=0,
sep_style=SeparatorStyle.ADD_BOS_EOS_TOKEN,
seps=["<|begin_of_text|>", "<|end_of_text|>"],
seps=["<|begin_of_text|>", "<|eot_id|>"],
)

default_conversation = LLaMA3_Conv
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def supervised_tokenize_sft(

assert (
tokenizer.bos_token == conversation_template.seps[0] and tokenizer.eos_token == conversation_template.seps[1]
), "`bos_token` and `eos_token` should be the same with `conversation_template.seps`."
), f"`bos_token`{tokenizer.bos_token} and `eos_token`{tokenizer.eos_token} should be the same with `conversation_template.seps`{conversation_template.seps}."

if ignore_index is None:
ignore_index = IGNORE_INDEX
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def save_checkpoint(
step: int,
batch_size: int,
coordinator: DistCoordinator,
use_lora: bool = False,
) -> None:
"""
Save model checkpoint, optimizer, LR scheduler and intermedidate running states.
Expand All @@ -51,7 +52,10 @@ def save_checkpoint(
save_dir = os.path.join(save_dir, f"epoch-{epoch}_step-{step}")
os.makedirs(os.path.join(save_dir, "modeling"), exist_ok=True)

booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)
if use_lora:
booster.save_lora_as_pretrained(model, os.path.join(save_dir, "modeling"))
else:
booster.save_model(model, os.path.join(save_dir, "modeling"), shard=True)

booster.save_optimizer(optimizer, os.path.join(save_dir, "optimizer"), shard=True)
booster.save_lr_scheduler(lr_scheduler, os.path.join(save_dir, "lr_scheduler"))
Expand Down
115 changes: 61 additions & 54 deletions applications/Colossal-LLaMA/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from colossal_llama.utils.froze import freeze_non_embeds_parameters
from colossal_llama.utils.neftune_patch import activate_neftune, deactivate_neftune
from colossal_llama.utils.utils import all_reduce_mean, format_numel_str, get_model_numel
from peft import LoraConfig
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer
Expand Down Expand Up @@ -65,7 +66,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "gemini_auto":
Expand All @@ -75,7 +76,7 @@ def train(args) -> None:
initial_scale=2**16,
max_norm=args.grad_clip,
enable_gradient_accumulation=(args.accumulation_steps > 1),
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_flash_attention=args.use_flash_attn,
)
elif args.plugin == "zero2":
Expand All @@ -101,10 +102,9 @@ def train(args) -> None:
sequence_parallelism_mode=args.sp_mode,
zero_stage=args.zero_stage,
enable_flash_attention=args.use_flash_attn,
enable_fused_normalization=torch.cuda.is_available(),
enable_fused_normalization=get_accelerator().is_available(),
enable_sequence_parallelism=args.enable_sequence_parallelism,
cpu_offload=True if args.zero_stage >= 1 and args.zero_cpu_offload else False,
parallel_output=False,
max_norm=args.grad_clip,
precision=args.mixed_precision,
microbatch_size=args.microbatch_size,
Expand All @@ -117,11 +117,17 @@ def train(args) -> None:
# ======================================================
# Initialize Tokenizer, Dataset, Collator and Dataloader
# ======================================================
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
tokenizer = AutoTokenizer.from_pretrained(args.pretrained, trust_remote_code=True)
if args.pad_token == "eos":
tokenizer.pad_token = tokenizer.eos_token
try:
tokenizer.pad_token = tokenizer.eos_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
elif args.pad_token == "unk":
tokenizer.pad_token = tokenizer.unk_token
try:
tokenizer.pad_token = tokenizer.unk_token
except AttributeError:
coordinator.print_on_master(f"pad_token can't be set")
tokenizer.add_bos_token = False
tokenizer.add_eos_token = False

Expand Down Expand Up @@ -164,33 +170,31 @@ def train(args) -> None:
# ======================================================
# Initialize Model, Objective, Optimizer and LR Scheduler
# ======================================================
# When training the ChatGLM model, LoRA and gradient checkpointing are incompatible.
init_ctx = (
LazyInitContext(default_device=get_current_device())
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)) and args.lora_rank == 0
else nullcontext()
)
with init_ctx:
if args.use_flash_attn:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
attn_implementation="flash_attention_2",
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
else:
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
args.pretrained,
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
trust_remote_code=True,
)
# Freeze part of parameters.
if args.freeze_non_embeds_params:
freeze_non_embeds_parameters(model=model)

if args.lora_rank > 0:
lora_config = LoraConfig(task_type="CAUSAL_LM", r=args.lora_rank, lora_alpha=32, lora_dropout=0.1)
model = booster.enable_lora(model, lora_config=lora_config)

# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
model.gradient_checkpointing_enable()
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")

model_numel = get_model_numel(model)
Expand Down Expand Up @@ -327,6 +331,7 @@ def train(args) -> None:
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
Expand Down Expand Up @@ -371,44 +376,45 @@ def train(args) -> None:
total_loss.fill_(0.0)
pbar.update()

# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)
# Save modeling.
save_model_condition = (
args.save_interval > 0 and (step + 1) % (args.save_interval * args.accumulation_steps) == 0
)

if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)
if not args.skip_save_each_epoch:
save_model_condition = save_model_condition or (step + 1) == len(dataloader)

if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")
if save_model_condition and not args.benchmark:
coordinator.print_on_master("\nStart saving model checkpoint with running states")

if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)
if args.use_neft:
coordinator.print_on_master("Deactivate NEFTune before saving model.")
deactivate_neftune(model, handle)

accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)
accelerator.empty_cache()
save_checkpoint(
save_dir=args.save_dir,
booster=booster,
model=model,
optimizer=optimizer,
lr_scheduler=lr_scheduler,
epoch=epoch,
step=step + 1,
batch_size=args.batch_size,
coordinator=coordinator,
use_lora=(args.lora_rank > 0),
)
coordinator.print_on_master(
f"Saved checkpoint at epoch {epoch} step {step + 1} at folder {args.save_dir}"
)

if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)
if args.use_neft:
coordinator.print_on_master("Activate NEFTune.")
model, handle = activate_neftune(model)

# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()
# Delete cache.
# del batch, batch_labels, batch_output, loss
accelerator.empty_cache()

# the continue epochs are not resumed, so we need to reset the sampler start index and start step
dataloader.sampler.set_start_index(start_index=0)
Expand Down Expand Up @@ -522,6 +528,7 @@ def train(args) -> None:
parser.add_argument(
"--microbatch_size", type=int, default=1, help="Batch size for each process in PP, used for 3d plugin."
)
parser.add_argument("--lora_rank", type=int, default=0, help="lora rank when using lora to train.")

# Additional arguments for benchmark.
parser.add_argument("--num_samples", type=int, default=500, help="Number of samples for benchmarking.")
Expand Down
6 changes: 3 additions & 3 deletions colossalai/lazy/lazy_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -509,9 +509,9 @@ def wrap_factory_like_method(orig_target, target):
# factory_like functions (eg. torch.empty_like())
def wrapper(*args, **kwargs):
orig_t = args[0]
return self.tensor_cls(
orig_target, *orig_t.shape, *args[1:], device=orig_t.device, dtype=orig_t.dtype, **kwargs
)
device = kwargs.pop("device", orig_t.device)
dtype = kwargs.pop("dtype", orig_t.dtype)
return self.tensor_cls(orig_target, *orig_t.shape, *args[1:], device=device, dtype=dtype, **kwargs)

return wrapper, target

Expand Down
2 changes: 1 addition & 1 deletion colossalai/legacy/communication/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _communicate(
for req in reqs:
req.wait()
# To protect against race condition when using batch_isend_irecv().
torch.cuda.synchronize()
get_accelerator().synchronize()

if recv_prev and recv_prev_split:
if isinstance(tensor_recv_prev, torch.Tensor):
Expand Down
33 changes: 20 additions & 13 deletions colossalai/pipeline/p2p.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from torch.distributed import distributed_c10d as c10d
from torch.utils._pytree import tree_flatten, tree_unflatten

from colossalai.accelerator import get_accelerator

from .stage_manager import PipelineStageManager


Expand All @@ -31,7 +33,7 @@ def _cuda_safe_tensor_to_object(tensor: torch.Tensor, tensor_size: torch.Size) -
buf = tensor.numpy().tobytes()[:tensor_size]
if b"cuda" in buf:
buf_array = bytearray(buf)
device_index = torch.cuda.current_device()
device_index = get_accelerator().current_device()
# There might be more than one output tensors during forward
for cuda_str in re.finditer(b"cuda", buf_array):
pos = cuda_str.start()
Expand Down Expand Up @@ -86,7 +88,7 @@ def _broadcast_object_list(
else:
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device("cuda", get_accelerator().current_device())

my_rank = dist.get_rank()
# Serialize object_list elements to tensors on src rank.
Expand Down Expand Up @@ -139,29 +141,29 @@ def _broadcast_object_list(
# unconsistence in device
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != torch.cuda.current_device()
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.cuda()
unpickle_object = unpickle_object.to(get_accelerator().current_device())

object_list[i] = unpickle_object


def _check_for_nccl_backend(group):
def _check_for_nccl_hccl_backend(group):
pg = group or c10d._get_default_group()
# Gate PG wrapper check on Gloo availability.
if c10d._GLOO_AVAILABLE:
# It is not expected for PG to be wrapped many times, but support it just in case
while isinstance(pg, c10d._ProcessGroupWrapper):
pg = pg.wrapped_pg

return c10d.is_nccl_available() and pg.name() == c10d.Backend.NCCL
return (c10d.is_nccl_available() or torch.distributed.is_hccl_available()) and pg.name() == c10d.Backend.NCCL


def _check_device(group):
is_nccl_backend = _check_for_nccl_backend(group)
is_nccl_backend = _check_for_nccl_hccl_backend(group)
current_device = torch.device("cpu")
if is_nccl_backend:
current_device = torch.device("cuda", torch.cuda.current_device())
current_device = torch.device(get_accelerator().current_device())
return current_device, is_nccl_backend


Expand Down Expand Up @@ -348,8 +350,11 @@ def _send_recv_serialization_object(

unpickle_object = _cuda_safe_tensor_to_object(recv_object_tensor, recv_object_size_tensor.item())

if isinstance(unpickle_object, torch.Tensor) and unpickle_object.device.index != torch.cuda.current_device():
unpickle_object = unpickle_object.cuda()
if (
isinstance(unpickle_object, torch.Tensor)
and unpickle_object.device.index != get_accelerator().current_device()
):
unpickle_object = unpickle_object.to(get_accelerator().current_device())

return unpickle_object

Expand Down Expand Up @@ -474,9 +479,11 @@ def _p2p_comm(
recv_prev_shape = None

if tensor_send_next is not None:
send_next_shape = torch.tensor(tensor_send_next.size(), device=torch.cuda.current_device(), dtype=torch.int64)
send_next_shape = torch.tensor(
tensor_send_next.size(), device=get_accelerator().current_device(), dtype=torch.int64
)
if recv_prev:
recv_prev_shape = torch.empty((3), device=torch.cuda.current_device(), dtype=torch.int64)
recv_prev_shape = torch.empty((3), device=get_accelerator().current_device(), dtype=torch.int64)

ops = []
if send_next_shape is not None:
Expand All @@ -501,7 +508,7 @@ def _p2p_comm(
# send and recv data
tensor_recv_prev = None
if recv_prev:
tensor_recv_prev = torch.empty(recv_prev_shape, device=torch.cuda.current_device(), dtype=comm_dtype)
tensor_recv_prev = torch.empty(recv_prev_shape, device=get_accelerator().current_device(), dtype=comm_dtype)

ops = []
if tensor_send_next is not None:
Expand Down
3 changes: 1 addition & 2 deletions colossalai/pipeline/schedule/interleaved_pp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

import torch
import torch.cuda
import torch.distributed
from torch.nn import Module, ModuleList
from torch.utils._pytree import tree_map
Expand All @@ -18,7 +17,7 @@
from .base import PipelineSchedule


def _wait_p2p(wait_handles: List[torch.cuda.Event]) -> None:
def _wait_p2p(wait_handles) -> None:
if wait_handles is not None:
for req in wait_handles:
req.wait()
Expand Down
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/one_f_one_b.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from typing import Any, Callable, Dict, Iterable, List, Optional, Union

import torch
import torch.cuda
from torch.nn import Module
from torch.utils._pytree import tree_map

Expand Down
Loading
Loading