diff --git a/applications/ColossalChat/examples/training_scripts/lora_finetune.py b/applications/ColossalChat/examples/training_scripts/lora_finetune.py index 851ad6a2d9e3..4045556d7ece 100644 --- a/applications/ColossalChat/examples/training_scripts/lora_finetune.py +++ b/applications/ColossalChat/examples/training_scripts/lora_finetune.py @@ -257,7 +257,7 @@ def is_master(): ) torch.set_default_dtype(torch.float) - booster.load_model(model, args.pretrained) + booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8) coordinator.print_on_master( f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB" diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py index 4b1224c68ffd..a81f9b05d7d7 100644 --- a/colossalai/booster/plugin/gemini_plugin.py +++ b/colossalai/booster/plugin/gemini_plugin.py @@ -85,11 +85,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) for k, v in state_dict.items(): - self.pinned_state_dicts[id(model)][k].copy_(v) - state_dict[k] = self.pinned_state_dicts[id(model)][k] + self.pinned_state_dicts[hash(model)][k].copy_(v) + state_dict[k] = self.pinned_state_dicts[hash(model)][k] writer = save(checkpoint, state_dict) self.async_writers.append(writer) else: @@ -172,9 +172,9 @@ def save_sharded_model( Path(checkpoint_path).mkdir(parents=True, exist_ok=True) if use_async and self.coordinator.is_master(): - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = model.state_dict_shard( diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index 1684fd702e70..1e0f7be240f6 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -26,6 +26,7 @@ from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO from colossalai.cluster import ProcessGroupMesh from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper +from colossalai.interface.model import PeftUnwrapMixin from colossalai.interface.optimizer import DistributedOptim from colossalai.logging import get_dist_logger from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed @@ -225,7 +226,7 @@ def unwrap(self, unwrap_peft: bool = True): if isinstance(model, DDP): model = model.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model def _force_wait_all_gather(self): diff --git a/colossalai/booster/plugin/torch_ddp_plugin.py b/colossalai/booster/plugin/torch_ddp_plugin.py index e74b1a9598b9..9cb5adf01972 100644 --- a/colossalai/booster/plugin/torch_ddp_plugin.py +++ b/colossalai/booster/plugin/torch_ddp_plugin.py @@ -12,6 +12,7 @@ from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO from colossalai.cluster import DistCoordinator from colossalai.interface import ModelWrapper, 
OptimizerWrapper +from colossalai.interface.model import PeftUnwrapMixin from colossalai.logging import get_dist_logger from colossalai.quantization import BnbQuantizationConfig, quantize_model from colossalai.utils import get_current_device @@ -201,7 +202,7 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None: def unwrap(self, unwrap_peft: bool = True) -> nn.Module: model = self.module.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model diff --git a/colossalai/booster/plugin/torch_fsdp_plugin.py b/colossalai/booster/plugin/torch_fsdp_plugin.py index d713203fe905..6e652e549b5e 100644 --- a/colossalai/booster/plugin/torch_fsdp_plugin.py +++ b/colossalai/booster/plugin/torch_fsdp_plugin.py @@ -103,11 +103,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state) for k, v in full_model_state.items(): - self.pinned_state_dicts[id(model)][k].copy_(v) - full_model_state[k] = self.pinned_state_dicts[id(model)][k] + self.pinned_state_dicts[hash(model)][k].copy_(v) + full_model_state[k] = self.pinned_state_dicts[hash(model)][k] writer = save(checkpoint, full_model_state) self.async_writers.append(writer) else: @@ -186,9 +186,9 @@ def save_sharded_model( state_dict = model.unwrap().state_dict() if use_async and self.coordinator.is_master(): - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = utils.shard_model_checkpoint( diff --git a/colossalai/checkpoint_io/general_checkpoint_io.py b/colossalai/checkpoint_io/general_checkpoint_io.py index 3e600c94dfc5..5dfb09248b53 100644 --- a/colossalai/checkpoint_io/general_checkpoint_io.py +++ b/colossalai/checkpoint_io/general_checkpoint_io.py @@ -60,9 +60,9 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import move_and_save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) - writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)]) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) + writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)]) self.async_writers.append(writer) else: # save the checkpoint @@ -234,7 +234,7 @@ def save_sharded_model( index_file = CheckpointIndexFile(checkpoint_path) if use_async: - pinned_state_dict = self.pinned_state_dicts.get(id(model), None) + pinned_state_dict = self.pinned_state_dicts.get(hash(model), None) total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards( sharded_state_dict=state_dict_shard, checkpoint=checkpoint_path, @@ -243,7 +243,7 @@ def save_sharded_model( is_master=True, pinned_state_dict=pinned_state_dict, ) - self.pinned_state_dicts[id(model)] = new_pinned_state_dict + self.pinned_state_dicts[hash(model)] = new_pinned_state_dict self.async_writers.extend(writers) else: # Save shards 
of optimizer states. diff --git a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py index 5de32e66655c..9d972635214d 100644 --- a/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py +++ b/colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py @@ -249,9 +249,9 @@ def save_sharded_model( # Only devices with tp_rank == 0 are responsible for model saving. control_saving = self.tp_rank == 0 and self.sp_rank == 0 if control_saving and use_async: - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = {} - pinned_state_dicts = self.pinned_state_dicts[id(model)] + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = {} + pinned_state_dicts = self.pinned_state_dicts[hash(model)] else: pinned_state_dicts = None state_dict_shard = HybridParallelCheckpointIO._model_sharder( @@ -789,11 +789,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict) for name, param in state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - state_dict[name] = self.pinned_state_dicts[id(model)][name] + self.pinned_state_dicts[hash(model)][name].copy_(param) + state_dict[name] = self.pinned_state_dicts[hash(model)][name] writer = save(path=checkpoint, state_dict=state_dict) self.async_writers.append(writer) else: @@ -811,11 +811,11 @@ def save_unsharded_model( if use_async: from colossalai.utils.safetensors import save - if id(model) not in self.pinned_state_dicts: - self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict) + if hash(model) not in self.pinned_state_dicts: + self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict) for name, param in complete_state_dict.items(): - self.pinned_state_dicts[id(model)][name].copy_(param) - complete_state_dict[name] = self.pinned_state_dicts[id(model)][name] + self.pinned_state_dicts[hash(model)][name].copy_(param) + complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name] writer = save(path=checkpoint, state_dict=complete_state_dict) self.async_writers.append(writer) else: diff --git a/colossalai/checkpoint_io/moe_checkpoint.py b/colossalai/checkpoint_io/moe_checkpoint.py index 586c7863f4bf..85e36f7c6336 100644 --- a/colossalai/checkpoint_io/moe_checkpoint.py +++ b/colossalai/checkpoint_io/moe_checkpoint.py @@ -701,15 +701,18 @@ def pre_save_model(self, model: nn.Module) -> dict: all_param = None # gather param from every ep rank # dist.all_gather(all_param, param, group=ep_group) - dist.gather(param, all_param, group=ep_group) + dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group) if ep_rank == 0: all_param = torch.cat(all_param, dim=0) state_dict[name] = all_param.cpu() if self.pp_size > 1: if self.dp_rank == 0: - out = [None for _ in range(self.pp_size)] - dist.gather_object(state_dict, out, group=self.pp_group) + if self.pp_rank == 0: + out = [None for _ in range(self.pp_size)] + else: + out = None + dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group) if self.pp_rank == 0: new_state_dict = {} for o in out: diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py 
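An illustrative aside on the moe_checkpoint.py change above (a sketch, not code from this patch): torch.distributed.gather and gather_object interpret dst as a global rank, not a rank inside the sub-group, so gathering onto the first expert-parallel or pipeline-parallel rank has to translate group rank 0 via dist.get_global_rank, exactly as the new code does. Minimal pattern, assuming the process group is already initialised; the helper name is hypothetical:

import torch
import torch.distributed as dist

def gather_param_to_group_root(param: torch.Tensor, ep_group):
    """Gather `param` from every rank of `ep_group` onto the group's first rank."""
    ep_rank = dist.get_rank(ep_group)
    ep_size = dist.get_world_size(ep_group)
    # the gather list only needs to exist on the destination rank
    gather_list = [torch.empty_like(param) for _ in range(ep_size)] if ep_rank == 0 else None
    dst = dist.get_global_rank(ep_group, 0)  # map group rank 0 to its global rank
    dist.gather(param, gather_list, dst=dst, group=ep_group)
    return gather_list  # populated only on the group root, None elsewhere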
index 2d826bd15f52..4b36dbe002bb 100644 --- a/colossalai/checkpoint_io/utils.py +++ b/colossalai/checkpoint_io/utils.py @@ -20,6 +20,7 @@ from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten from colossalai.accelerator import get_accelerator +from colossalai.interface.model import PeftUnwrapMixin from colossalai.tensor.d_tensor import ( is_customized_distributed_tensor, is_distributed_tensor, @@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model except ImportError: return + if isinstance(model, PeftUnwrapMixin): + model = model.base_model if not isinstance(model, PreTrainedModel): return @@ -692,6 +695,9 @@ def load_state_dict_into_model( state_dict (dict): a dict containing parameters and persistent buffers. """ + if isinstance(model, PeftUnwrapMixin): + state_dict = model.patch_state_dict(state_dict) + model = model.base_model if not isinstance(state_dict, Mapping): raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict))) diff --git a/colossalai/interface/model.py b/colossalai/interface/model.py index d112c27230b0..8dbd15c63cf9 100644 --- a/colossalai/interface/model.py +++ b/colossalai/interface/model.py @@ -1,5 +1,102 @@ +import re +from typing import Dict, Set + +import torch import torch.nn as nn -from peft import PeftModel +from peft import PeftModel, PeftType + + +def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"): + config = model.peft_config[adapter_name] + if config.peft_type != PeftType.LORA: + raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.") + # to_return = lora_state_dict(model, bias=model.peft_config.bias) + # adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py` + # to be used directly with the state dict which is necessary when using DeepSpeed or FSDP + bias = config.bias + if bias == "none": + to_return = {k for k in names if "lora_" in k} + elif bias == "all": + to_return = {k for k in names if "lora_" in k or "bias" in k} + elif bias == "lora_only": + to_return = set() + for k in names: + if "lora_" in k: + to_return.add(k) + bias_name = k.split("lora_")[0] + "bias" + if bias_name in names: + to_return.add(bias_name) + else: + raise NotImplementedError + to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))} + if config.use_dora: + # Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a + # ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since + # we want the state_dict format not to change, we remove the "weight" part. 
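For context on the id(model) to hash(model) switch that runs through the checkpoint hooks above: unwrap() now returns a fresh PeftUnwrapMixin wrapper on every call (defined below), and that wrapper's __hash__ delegates to the underlying base model, so the likely motivation is keeping a single stable pinned-state-dict entry per model across repeated saves, which id() on the short-lived wrapper would not give. A minimal sketch of that behaviour, using hypothetical stand-in names rather than the real classes:

import torch.nn as nn

class _FakeUnwrapMixin:  # stand-in for PeftUnwrapMixin; hashes like the real one
    def __init__(self, base_model: nn.Module):
        self.base_model = base_model

    def __hash__(self):
        return hash(self.base_model)

base = nn.Linear(4, 4)
pinned_cache = {}  # plays the role of self.pinned_state_dicts
for _ in range(2):  # two save calls produce two distinct wrapper objects
    wrapper = _FakeUnwrapMixin(base)
    if hash(wrapper) not in pinned_cache:  # keying by id(wrapper) would miss on the second pass
        pinned_cache[hash(wrapper)] = {}  # pretend these are pinned CPU buffers
assert len(pinned_cache) == 1  # the single entry is reused across both saves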
+ new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight" + + def renamed_dora_weights(k): + if k.endswith(new_dora_suffix): + k = k[:-7] # remove ".weight" + return k + + to_return = {renamed_dora_weights(k) for k in to_return} + + to_return = {re.sub(rf"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return} + return to_return + + +class PeftUnwrapMixin: + def __init__(self, peft_model: PeftModel): + self.base_model = peft_model.get_base_model() + # peft does not affect buffers + self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters())) + potential_lora_weights = set() + for n in self.lora_layers: + potential_lora_weights.add(f"{n}.weight") + potential_lora_weights.add(f"{n}.bias") + self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights} + self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()} + + def named_parameters(self): + for n, p in self.base_model.named_parameters(): + if n in self.lora_param_to_origin_param: + n = self.lora_param_to_origin_param[n] + yield n, p + + def named_buffers(self): + return self.base_model.named_buffers() + + @property + def _modules(self): + return self.base_model._modules + + @property + def _non_persistent_buffers_set(self): + return self.base_model._non_persistent_buffers_set + + def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]): + new_state_dict = {} + for k, v in state_dict.items(): + if k in self.origin_param_to_lora_param: + k = self.origin_param_to_lora_param[k] + new_state_dict[k] = v + return new_state_dict + + def state_dict(self): + state_dict = {} + for k, v in self.base_model.state_dict().items(): + if k in self.lora_param_to_origin_param: + k = self.lora_param_to_origin_param[k] + state_dict[k] = v + return state_dict + + def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False): + state_dict = self.patch_state_dict(state_dict) + self.base_model.load_state_dict(state_dict, strict=strict, assign=assign) + + def __hash__(self): + return hash(self.base_model) class ModelWrapper(nn.Module): @@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True): else: model = self.module if unwrap_peft and isinstance(model, PeftModel): - model = model.get_base_model() + model = PeftUnwrapMixin(model) return model def forward(self, *args, **kwargs): diff --git a/colossalai/pipeline/schedule/dualpipe_schedule.py b/colossalai/pipeline/schedule/dualpipe_schedule.py new file mode 100644 index 000000000000..c19d800c989e --- /dev/null +++ b/colossalai/pipeline/schedule/dualpipe_schedule.py @@ -0,0 +1,1490 @@ +from math import ceil, floor +from typing import List + +from .v_schedule import ScheduledNode + +DUALPIPE_NODETYPE = {"F", "B", "W", "Full_B", "EMPTY_BUBBLE"} + + +class DualPipeGraph(object): + """DualPipeGraph + We break DualPipe down into three Pipe_Stages: Warmup, Middle, End + Warmup contains: + step1: no_cross_fwd + step2: cross_fwd + step3: warmup_1F1B1W + step4: warmup_transitions + Middle contains: (named after their shape in the pipe) + step1: mid_rhombic + step2: mid_butterfly + step3: mid_transitions + End contains: + step1: bwdB_step + step2: cross_bwdB_bwdW + step3: bwdW_step + """ + + def __init__( + self, + n_stage, + n_micro, + f_cost: int = 1, + b_cost: int = 1, + w_cost: int = 1, + c_cost: int = 1, + f_mem: int = 1, + b_mem: int = 1, + w_mem: int = 1, + max_mem: int = None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + 
self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + # time unit + self.one_time_unit = 1 + # total mbs (both up and down) + self.total_mbs = (n_stage + 2) * 2 + # one side mbs + self.mbs = self.total_mbs // 2 + + def print_details( + self, + pipeline_schedule: List[List[ScheduledNode]], + chunk_mode: bool = False, + mbs_mode: bool = False, + empty_bubble_str_mode: bool = False, + ): + assert not ( + chunk_mode and mbs_mode + ), "Only one mode is supported at the same time, please choose from chunk_mode and mbs_mode" + schedule_str = "" + for stage in range(self.n_stage): + stage_nodes = [] + for node in pipeline_schedule[stage]: + if node.type in DUALPIPE_NODETYPE: + if node.type == "EMPTY_BUBBLE": + if empty_bubble_str_mode: + stage_nodes.append("E") + else: + stage_nodes.append(" ") + else: + if chunk_mode: + stage_nodes.append(node.type + str(node.chunk)) + elif mbs_mode: + stage_nodes.append(node.type + str(node.minibatch)) + else: + stage_nodes.append(node.type) + stage_str = "".join([_ for _ in stage_nodes]) + schedule_str += "\n" + stage_str + print(schedule_str) + + def get_pipe_first_b_w(self, stage_pipe: List[ScheduledNode], chunk: int = 0): + # get first d, last d, first u, last u B node in range[first B, first W] + first_d, last_d, first_u, last_u = self.n_micro // 2, 0, self.n_micro // 2, 0 + stage_pipe_temp = [] + for node in stage_pipe[::-1]: + if node.type == "Full_B": + break + else: + stage_pipe_temp.append(node) + stage_pipe = stage_pipe_temp[::-1] # node from last fully B to ... 
+ # print(f"stage_pipe {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in stage_pipe]}") + if chunk == 0: + # get first d + for node in stage_pipe: + if node.type == "B" and node.chunk == 1: + first_d = node.minibatch + break + + # get first u + for node in stage_pipe: + if node.type == "B" and node.chunk == 0: + first_u = node.minibatch + break + + # get last_d + for node in stage_pipe[::-1]: + if node.type == "B" and node.chunk == 1: + last_d = node.minibatch + break + + # get last_u + for node in stage_pipe[::-1]: + if node.type == "B" and node.chunk == 0: + last_u = node.minibatch + break + else: + # get first d + for node in stage_pipe: + if node.type == "B" and node.chunk == 1: + first_d = node.minibatch + break + + # get first u + for node in stage_pipe: + if node.type == "B" and node.chunk == 0: + first_u = node.minibatch + break + + # get last_d + for node in stage_pipe[::-1]: + if node.type == "B" and node.chunk == 1: + last_d = node.minibatch + break + + # get last_u + for node in stage_pipe[::-1]: + if node.type == "B" and node.chunk == 0: + last_u = node.minibatch + break + return first_d, last_d, first_u, last_u + + def cross_merge_nodes( + self, node_list1: List[ScheduledNode], node_list2: List[ScheduledNode] + ) -> List[ScheduledNode]: + """ + corss merge node in Step: get_end_schedule-->cross_bwdB_bwdW + example 1: + inputs: + node_list1:[Node 1, Node 3, Node 5] + node_list2:[Node 2, Node 4, Node 6] + return: + node_list3:[Node 1, Node 2, Node 3, Node 4, Node 5, Node 6] + + example 2: + inputs: + node_list1:[Node 1, Node 3, Node 5] + node_list2:[Node 2,] + return: + node_list3:[Node 1, Node 2, Node 3, Node 5] + """ + merged = [] + for x, y in zip(node_list1, node_list2): + merged.append(x) + merged.append(y) + merged += node_list1[len(node_list2) :] # deal list1 rest ele + merged += node_list2[len(node_list1) :] # deal list2 rest ele + return merged + + ################ + # Pipe_Stage 1 + ################ + def get_warmup_schedule(self, pipeline_schedule: List[List[ScheduledNode]]): + ########### Pipe_Stage 1.1 ########### + def no_cross_fwd(pipeline_schedule: List[List[ScheduledNode]]): + # stage [0,pp/2) + for stage in range(0, self.n_stage // 2): + # add num i empty bubble + start_time = 0 + for i in range(0, stage): + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=start_time, + completion_time=start_time + self.one_time_unit, + ) + ) + start_time += self.one_time_unit + # add FWD node + # Stage i in [0, pp/2) mbs m in range [0, (pp - 1) - 2i) model chunk 0 fwd + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + for i in range(0, (self.n_stage - 1) - 2 * stage): + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=i, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # stage [pp/2,n_stage) + for stage in range(self.n_stage // 2, self.n_stage): + start_time = 0 + for i in range(0, self.n_stage - stage - 1): + # add num i empty bubble + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=start_time, + completion_time=start_time + self.one_time_unit, + ) + ) + start_time += self.one_time_unit + # add FWD node + # Stage i in [pp/2, pp) , mbs m in range [0, 2i - (pp-1)) model chunk 1 fwd + curr_time = pipeline_schedule[stage][-1].completion_time 
if pipeline_schedule[stage] else 0 + for i in range(0, 2 * stage - (self.n_stage - 1)): + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=i, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 1.2 ########### + # For each stage, add schedule Nodes col pp/2 times (range 0 to self.n_stage//2 + 1), + def cross_fwd(pipeline_schedule: List[List[ScheduledNode]]): + for r in range(0, self.n_stage // 2 + 1): + if r == 0: + # special case; only one col 0d 0u + for stage in range(r, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + if r % 2 != 0: + for stage in range(r, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [r, pp/2) , mbs为 (pp - 1) - 2i + (r-1) model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=(self.n_stage - 1) - 2 * stage + (r - 1), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # Stage i in [r, pp/2) , mbs r model chunk 1 fwd # 全1d + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [pp/2, pp-r), mbs 为2i - (pp-1) + (r-1) model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=2 * stage - (self.n_stage - 1) + (r - 1), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # Stage i in [pp/2, pp-r) 压入mbs r model chunk 0 fwd # 全1u + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + if r % 2 == 0: + for stage in range(r, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [r, pp/2) , mbs为 (pp - 1) - 2i + (r-1) model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage - 2 * stage + (r - 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # Stage i in [r, pp/2) , mbs r model chunk 1 fwd # 全1d + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + 
completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [pp/2, pp-r), mbs 为2i - (pp-1) + (r-1) model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=2 * stage - (self.n_stage - 2) + (r - 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # Stage i in [pp/2, pp-r) 压入mbs r model chunk 0 fwd # 全1u + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 1.3 ########### + def warmup_1f1b1w(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, add Schedule Nodes pp/2 times from (0, self.n_stage//2 + 1) + for r in range(0, self.n_stage // 2): + # [0, pp/2 - r) + for stage in range(0, self.n_stage // 2 - 1 - r): + ###### bwd b ###### + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [r, pp/2) , mbs为 (pp - 1) - 2i + (r-1) model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### bwd w ###### + # Stage i in [0, pp/2 - r) , mbs r model chunk 1 bwd w + pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### fwd ###### + # Stage i in [0, pp/2 - r), mbs i + 1 + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=stage + 1 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # [pp/2 + 1 + r, pp) + for stage in range(self.n_stage // 2 + 1 + r, self.n_stage): + ###### bwd b ###### + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [0, pp/2 - r), mbs i + 1 + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### bwd w ###### + # Stage i in [pp/2 + 1 - r, pp) ,mbs r model chunk 0 bwd w + pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### fwd ###### + # Stage i in [0, pp/2 - r) , mbs i + 1 + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage - stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 1.4 ########### + def warmup_transitions(pipeline_schedule: List[List[ScheduledNode]]): + # For each stage, add pp/2 - 1 round Schedule Nodes + for r in range(0, self.n_stage // 2): + if r == 0: + # special round add 1 col 
fwd and 1 col fullyBwd + for stage in range(0, self.n_stage // 2): + ###### Fwd ###### + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [0, pp/2 ) , mbs pp - i - 1 - r model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage - stage - 1 - r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd ###### + # Stage i in [0, pp/2) , add mbs (pp/2) - i - 1 - r model chunk 0 Full_B + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=self.n_stage // 2 - stage - 1 - r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + for stage in range(self.n_stage // 2, self.n_stage): + ###### Fwd ###### + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [pp/2, pp), mbs i + r model chunk chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd ###### + # Stage i in [pp/2, pp), mbs i - (pp/2) + r model chunk chunk 1 FullyB + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=stage - self.n_stage // 2 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + for stage in range(0, self.n_stage // 2 - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [0, pp/2 - r) EMPTY_BUBBLE + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + for stage in range(self.n_stage // 2 + r, self.n_stage): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + # Stage i in [pp/2 + r, pp) 压入空泡 + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + no_cross_fwd(pipeline_schedule) + cross_fwd(pipeline_schedule) + warmup_1f1b1w(pipeline_schedule) + warmup_transitions(pipeline_schedule) + + ################ + # Pipe_Stage 2 + ################ + def get_middle_schedule(self, pipeline_schedule: List[List[ScheduledNode]]): + ########### Pipe_Stage 2.1 ########### + def mid_rhombic(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, add (pp/2) + 1 round(total 9 round)Schedule Nodes + for r in range(0, self.n_stage // 2 + 1): + if r == 0: + for stage in range(0, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd 1 ###### + # Stage i in [0, pp/2 ), mbs pp/2 model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd 1 ###### + # Stage i in [0, pp/2) , mbs r model chunk 0 
Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd 2 ###### + # Stage i in [0, pp/2 ) , mbs pp - i -r model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage - stage - r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully Bwd 2 ###### + # Stage i in [0, pp/2) , mbs (pp/2) - i - r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 - stage - r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + for stage in range(self.n_stage // 2, self.n_stage): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd 1 ###### + # Stage i in [pp/2, pp), mbs pp/2 + r model chunk chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage // 2 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd 1 ###### + # Stage i in [0, pp/2) , mbs r model chunk 0 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd 2 ###### + # Stage i in [pp/2, pp) , mbs i + 1 + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=stage + 1 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd 2 ###### + # Stage i in [pp/2, pp) , mbs i - ((pp/2) -1) + r model chunk chunk 0 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=stage - (self.n_stage // 2 - 1) + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + for stage in range(r - 1, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd 1 ###### + # Stage i in [r - 1, pp/2), mbs (pp/2) + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully Bwd 1 ###### + # Stage i in [r - 1, pp/2), mbs r model chunk 0 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd 2 ###### + # Stage i in [r - 1, pp/2) , mbs pp + r - i model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage + r - stage, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully Bwd 2 ###### + # Stage i in 
[r - 1, pp/2), mbs (pp/2) - i + r model chunk 1 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 - stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + for stage in range(self.n_stage // 2, self.n_stage - r + 1): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd 1 ###### + # Stage i in [pp/2, pp - r + 1) , mbs (pp/2) + r model chunk chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_stage // 2 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully Bwd 1 ###### + # Stage i in [pp/2, pp - r + 1), mbs r model chunk chunk 1 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd 2 ###### + # Stage i in [pp/2, pp - r + 1), mbs i + 1 + r model chunk chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=stage + 1 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully Bwd 2 ###### + # Stage i in [pp/2, pp - r + 1), mbs i - ((pp/2) -1) + r model chunk chunk 0 Fully Bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=stage - (self.n_stage // 2 - 1) + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 2.2 ########### + def mid_butterfly(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, add pp/2 round(total 8 round)Schedule Nodes + for r in range(0, self.n_stage // 2 + 1): + if r == 0: + for stage in range(0, self.n_stage // 2 - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd ###### + # Stage i in [0, pp/2 - r ), mbs bs//2-pp//2+i+r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=self.n_micro // 2 - self.n_stage // 2 + stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + for stage in range(self.n_stage // 2 + r, self.n_stage): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fwd ###### + # Stage i in [pp/2 + r, pp), mbs (bs//2 + pp//2 - 1)- i + r model chunk chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_micro // 2 + self.n_stage // 2 - 1 - stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + for stage in range(0, self.n_stage // 2 - r): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fully bwd 1 ###### + # Stage i in [0, pp/2 - r ), mbs i + r + 1 model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=stage + r + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + 
curr_time += self.one_time_unit + ###### Empty Bubble ###### + # Stage i in [0, pp/2 - r) , Empty Bubble + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully bwd 2 ###### + # Stage i in [0, pp/2 - r), mbs bs//2-pp//2 - 1 + r model chunk 1 Fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_micro // 2 - self.n_stage // 2 - 1 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd ###### + # Stage i in [0, pp/2 - r), mbs bs//2 - pp//2 + i + r model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=1, + stage=stage, + minibatch=self.n_micro // 2 - self.n_stage // 2 + stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2 + r, self.n_stage): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Fully bwd 1 ###### + # Stage i in [pp/2 + r, pp), mbs bs - i + r -2 model chunk 1 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage - stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Empty Bubble ###### + # Stage i in [pp/2 + r, pp), Empty Bubble + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ###### Fully bwd 2 ###### + # Stage i in [pp/2 + r, pp), mbs bs//2-pp//2 - 1 + r model chunk 0 Fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_micro // 2 - self.n_stage // 2 - 1 + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fwd ###### + # Stage i in [pp/2 + r, pp), mbs (bs/2+pp/2 - 1) - i + r model chunk 0 fwd + pipeline_schedule[stage].append( + ScheduledNode( + type="F", + chunk=0, + stage=stage, + minibatch=self.n_micro // 2 + self.n_stage // 2 - 1 - stage + r, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 2.3 ########### + def mid_transitions(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, add pp/2 + 1 round(total 9 round)Schedule Nodes + for r in range(0, self.n_stage // 2 + 1): + if r == 0: + for stage in range(r, self.n_stage // 2): + ###### Fully B ###### + # Stage i in [r, pp/2), mbs pp/2 + r + 1 model chunk 0 fully bwd + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=self.n_stage // 2 + r + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage): + ###### Fully B ###### + # Stage i in [pp/2, pp), mbs pp/2 + r + 1 model chunk 1 fully bwd + curr_time = 
pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 + r + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + if r % 2 != 0: # odd round: 1, 3, 5, 7 + for stage in range(r - 1, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Empty Bubble ###### + # Stage i in [r - 1, pp/2), Empty bubble + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully bwd ###### + # Stage i in [r - 1, pp/2), mbs pp + r - i model chunk 1 fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage + ceil(r / 2) - stage, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage - r + 1): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Empty Bubble ###### + # Stage i in [pp/2, pp - r + 1), Empty bubble + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully bwd ###### + # Stage i in [pp/2, pp - r + 1) 压入mbs i + 1 + r model chunk chunk 0 fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=stage + 1 + ceil(r / 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + else: # even round: 2, 4, 6, 8 + for stage in range(r - 1, self.n_stage // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Empty Bubble ###### + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully bwd ###### + # Stage i in [r - 1, pp/2), mbs pp//2 + r model chunk 0 fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=0, + stage=stage, + minibatch=self.n_stage // 2 + floor(r / 2) + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage - r + 1): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + ###### Empty Bubble ###### + pipeline_schedule[stage].append( + ScheduledNode( + type="EMPTY_BUBBLE", + chunk=0, + stage=stage, + minibatch=0, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ###### Fully bwd ###### + # Stage i in [pp/2, pp - r + 1), mbs pp//2 + r model chunk chunk 1 fully bwd + pipeline_schedule[stage].append( + ScheduledNode( + type="Full_B", + chunk=1, + stage=stage, + minibatch=self.n_stage // 2 + floor(r / 2) + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) 
+ curr_time += self.one_time_unit + + mid_rhombic(pipeline_schedule) + mid_butterfly(pipeline_schedule) + mid_transitions(pipeline_schedule) + + ################ + # Pipe_Stage 3 + ################ + def get_end_schedule(self, pipeline_schedule: List[List[ScheduledNode]]): + ########### Pipe_Stage 3.1 ########### + def bwdB_step(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, pp/2 round(total 8 round)Schedule Nodes, + for r in range(0, self.n_stage // 2): + if r % 2 == 0: + # Stage i in [pp/2 - r - 1, pp/2) + for stage in range(self.n_stage // 2 - r - 1, self.n_stage // 2): + # Stage i in [pp/2 - r - 1, pp/2), mbs (pp/2 - 1)*2 + r/2 model chunk 1 bwd B + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=1, + stage=stage, + minibatch=(self.n_stage // 2 - 1) * 2 + r // 2, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # Stage i in [pp/2, pp/2 + r + 1) + for stage in range(self.n_stage // 2, self.n_stage // 2 + r + 1): + # Stage i in [pp/2, pp/2 + r+ 1), mbs (pp/2 - 1)*2 + r/2 model chunk 0 bwd B + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=0, + stage=stage, + minibatch=(self.n_stage // 2 - 1) * 2 + r // 2, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + else: + # Stage i in [pp/2 - r - 1, pp/2) + for stage in range(self.n_stage // 2 - r - 1, self.n_stage // 2): + # Stage i in [pp/2 - r - 1, pp/2), mbs pp-1-(pp/2 - i)+ floor(r/2) model chunk 0 bwd B # [6:13,7:14] + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=0, + stage=stage, + minibatch=self.n_stage - 1 - (self.n_stage // 2 - stage) + floor(r / 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # Stage i in [pp/2, pp/2 + r+ 1) + for stage in range(self.n_stage // 2, self.n_stage // 2 + r + 1): + # Stage i in [pp/2, pp/2 + r+ 1), mbs pp-1-(i - pp/2 + 1) + floor(r/2) model chunk 1 bwd B # [8:14,9:13] + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="B", + chunk=1, + stage=stage, + minibatch=self.n_stage - 1 - (stage - self.n_stage // 2 + 1) + floor(r / 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + ########### Pipe_Stage 3.2 ########### + def cross_bwdB_bwdW(pipeline_schedule: List[List[ScheduledNode]]): + for stage in range(0, self.n_stage // 2 - 1): + first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=0) + # print(f"stage {stage} Up first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ") + u_queue_w, u_queue_b, d_queue_w = [], [], [] + ### 1.Get W nodes, then merge up/down W nodes ### + # get up W nodes: [first_u: mbs//2] + for _ in range(first_u, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + u_queue_w.append( + ScheduledNode( + type="W", + chunk=0, + stage=stage, + minibatch=_, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + 
curr_time += self.one_time_unit + # get down W nodes: [first_d: mbs//2] Bwd W to W Queue + for _ in range(first_d, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + d_queue_w.append( + ScheduledNode( + type="W", + chunk=1, + stage=stage, + minibatch=_, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + ### 2.Get B nodes, then cross with W ### + for _ in range(last_u, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + u_queue_b.append( + ScheduledNode( + type="B", + chunk=0, + stage=stage, + minibatch=_ + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # if stage % 2 == 0: u_queue_w first, then d_queue_w + if stage % 2 == 0: + w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) + wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) + # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + cut_idx = len(wb_nodes) + for _ in range(len(wb_nodes)): + if ( + wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + and wb_nodes[_].type == "B" + and wb_nodes[_].chunk == 0 + ): + cut_idx = _ + break + wb_nodes = wb_nodes[: cut_idx + 1] + # append nodes to stage pipe + pipeline_schedule[stage].extend(wb_nodes) + # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # else: d_queue_w first, then u_queue_w + else: + w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) + wb_nodes = self.cross_merge_nodes(w_nodes, u_queue_b) + # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + cut_idx = len(wb_nodes) + for _ in range(len(wb_nodes)): + if ( + wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + and wb_nodes[_].type == "B" + and wb_nodes[_].chunk == 0 + ): + cut_idx = _ + break + wb_nodes = wb_nodes[: cut_idx + 1] + # append nodes to stage pipe + pipeline_schedule[stage].extend(wb_nodes) + # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + + for stage in range(self.n_stage // 2 + 1, self.n_stage): + first_d, last_d, first_u, last_u = self.get_pipe_first_b_w(pipeline_schedule[stage], chunk=1) + # print(f"stage {stage} Down first_d {first_d}, last_d {last_d}, first_u {first_u}, last_u {last_u} ") + d_queue_w, d_queue_b, u_queue_w = [], [], [] + ### 1.Get W nodes, then merge down/up W nodes ### + # get down W nodes: [first_d: mbs // 2] chunk 1 + for _ in range(first_d, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + d_queue_w.append( + ScheduledNode( + type="W", + chunk=1, + stage=stage, + minibatch=_, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # print(f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]}") + # get up W nodes: [first_u: mbs//2] chunk 0 + for _ in range(first_u, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + u_queue_w.append( + ScheduledNode( + type="W", + chunk=0, + stage=stage, + minibatch=_, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # print(f"stage {stage} u_queue_w {[_.minibatch for _ in u_queue_w]}") + ### 2.Get B nodes, then cross with W ### + 
for _ in range(last_d, self.n_micro // 2): + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + d_queue_b.append( + ScheduledNode( + type="B", + chunk=1, + stage=stage, + minibatch=_ + 1, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # print(f"stage {stage} d_queue_b {[_.minibatch for _ in d_queue_b]}") + # print( + # f"stage {stage} d_queue_w {[_.minibatch for _ in d_queue_w]} d_queue_b {[_.minibatch for _ in d_queue_b]} u_queue_w {[_.minibatch for _ in u_queue_w]}" + # ) + if stage % 2 == 0: + w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) + # print(f"stage {stage} w_nodes {[_.minibatch for _ in w_nodes]}") + wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) + # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B' + cut_idx = len(wb_nodes) + for _ in range(len(wb_nodes)): + if ( + wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + and wb_nodes[_].type == "B" + and wb_nodes[_].chunk == 1 + ): + cut_idx = _ + break + wb_nodes = wb_nodes[: cut_idx + 1] + # append nodes to stage pipe + pipeline_schedule[stage].extend(wb_nodes) + # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + else: + w_nodes = self.cross_merge_nodes(d_queue_w, u_queue_w) + # print(f"stage {stage} w_nodes {[_.minibatch for _ in w_nodes]}") + wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) + # clean w nodes, let it stop at mbs // 2 - 1, chunk 1, type 'B' + cut_idx = len(wb_nodes) + for _ in range(len(wb_nodes)): + if ( + wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + and wb_nodes[_].type == "B" + and wb_nodes[_].chunk == 1 + ): + cut_idx = _ + break + wb_nodes = wb_nodes[: cut_idx + 1] + # append nodes to stage pipe + pipeline_schedule[stage].extend(wb_nodes) + # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + # # else: d_queue_w first, then u_queue_w + # else: + # w_nodes = self.cross_merge_nodes(u_queue_w, d_queue_w) + # wb_nodes = self.cross_merge_nodes(w_nodes, d_queue_b) + # # clean w nodes, let it stop at mbs // 2 - 1, chunk 0, type 'B' + # cut_idx = len(wb_nodes) + # for _ in range(len(wb_nodes)): + # if ( + # wb_nodes[_].minibatch == (self.n_micro // 2 - 1) + # and wb_nodes[_].type == "B" + # and wb_nodes[_].chunk == 1 + # ): + # cut_idx = _ + # break + # wb_nodes = wb_nodes[: cut_idx + 1] + # print(f"stage {stage} cut_idx {cut_idx} wb_nodes {[str(_.minibatch) + _.type + ('u' if _.chunk == 0 else 'd') for _ in wb_nodes]}") + + ########### Pipe_Stage 3.3 ########### + def bwdW_step(pipeline_schedule: List[List[ScheduledNode]]): + # for each stage, add pp/2 round(total 8 round)Schedule Nodes + for r in range(0, self.n_stage // 2): + for stage in range(0, self.n_stage // 2): + # Red up # [0, 7] + if stage in range(self.n_stage // 2 - r - 1, self.n_stage // 2 - 1): + # Stage i in [pp/2 - r, pp/2 - 1), mbs(pp/2 - 1)*2 + 向下取整(r/2) + 1 model chunk Chunk_num bwd W # None + # mbs_num = (pp/2 - 1)*2 + 向下取整(r/2) if r < pp //2 // 2 else (pp/2 - 1)*2 + (r - pp //2 // 2) + # chunk_num = 1 if r < pp //2 else 0 + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + chunk_num = 1 if r < self.n_stage // 2 // 2 else 0 + mbs_num = ( + (self.n_stage // 2 - 1) * 2 + r + if r < self.n_stage // 2 // 2 + else (self.n_stage // 2 - 1) * 2 + (r - self.n_stage // 2 // 2) + ) + 
pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=chunk_num, + stage=stage, + minibatch=mbs_num, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + # Blue up [7, 8] + if stage in range(self.n_stage // 2 - 1, self.n_stage // 2): + # Stage i in [pp/2 - 1, pp/2), mbs (pp/2 - 1)*2 + floor(r/2) model chunk 1 bwd W + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=1 if r % 2 == 0 else 0, + stage=stage, + minibatch=(self.n_stage // 2 - 1) * 2 + floor(r / 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + for stage in range(self.n_stage // 2, self.n_stage): + # Blue down [8, 9] + if stage in range(self.n_stage // 2, self.n_stage // 2 + 1): + # Stage i in [pp/2, pp/2 + 1), mbs (pp/2 - 1)*2 + 向下取整(r/2) - 1 model chunk chunk 0 bwd W + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=0 if r % 2 == 0 else 1, + stage=stage, + minibatch=(self.n_stage // 2 - 1) * 2 + floor(r / 2), + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + # Red Down # [9, 16] + if stage in range(self.n_stage // 2 + 1, self.n_stage // 2 + r + 1): + # Stage i in [pp/2 + 1, pp/2 + r+ 1), mbs(pp/2 - 1)*2 + floor(r/2) model chunk 0 bwd W + # mbs_num = (pp/2 - 1)*2 + 向下取整(r/2) if r < pp //2 // 2 else (pp/2 - 1)*2 + (r - pp //2 // 2) + # chunk_num = 1 if r < pp //2 else 0 + curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0 + chunk_num = 0 if r < self.n_stage // 2 // 2 else 1 + mbs_num = ( + (self.n_stage // 2 - 1) * 2 + r + if r < self.n_stage // 2 // 2 + else (self.n_stage // 2 - 1) * 2 + (r - self.n_stage // 2 // 2) + ) + pipeline_schedule[stage].append( + ScheduledNode( + type="W", + chunk=chunk_num, + stage=stage, + minibatch=mbs_num, + start_time=curr_time, + completion_time=curr_time + self.one_time_unit, + ) + ) + curr_time += self.one_time_unit + + bwdB_step(pipeline_schedule) + cross_bwdB_bwdW(pipeline_schedule) + bwdW_step(pipeline_schedule) + + def get_dualpipe_schedule( + self, + ): + pipeline_schedule = [[] for _ in range(self.n_stage)] + self.get_warmup_schedule(pipeline_schedule) + self.get_middle_schedule(pipeline_schedule) + self.get_end_schedule(pipeline_schedule) + return pipeline_schedule diff --git a/colossalai/pipeline/schedule/v_schedule.py b/colossalai/pipeline/schedule/v_schedule.py index d3e94c0ba8c8..10415e1aa6da 100644 --- a/colossalai/pipeline/schedule/v_schedule.py +++ b/colossalai/pipeline/schedule/v_schedule.py @@ -30,6 +30,7 @@ from collections import deque from dataclasses import dataclass +from typing import List @dataclass(eq=True, frozen=True) @@ -447,3 +448,82 @@ def even_breaker(x: ScheduledNode): assert len(rollback_comm) == 0 return local_order_with_rollback + + +class DualVPipelineGraph(PipelineGraph): + """DualVPipelineGraph: A cut-in-half combination of DualPipe and Zerobubble V""" + + def __init__( + self, + n_stage, + n_micro, + f_cost, + b_cost, + w_cost, + c_cost, + f_mem, + b_mem, + w_mem, + max_mem=None, + ): + self.n_node = 6 * n_stage * n_micro + self.n_stage = n_stage + self.n_micro = n_micro + self.f_cost = f_cost + self.b_cost = b_cost + self.w_cost = w_cost + self.c_cost = c_cost + 
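Every node emitted by the schedule builders above repeats the same bookkeeping: read the completion time of the stage's last node (0 for an empty stage), construct a ScheduledNode that starts there, and advance by one time unit. A small factory such as the one below could remove that repetition; make_node and ONE_TIME_UNIT are hypothetical names used only for this sketch.

from typing import List

from colossalai.pipeline.schedule.v_schedule import ScheduledNode

ONE_TIME_UNIT = 1  # stands in for self.one_time_unit


def make_node(
    pipeline_schedule: List[List[ScheduledNode]],
    stage: int,
    node_type: str,
    chunk: int,
    minibatch: int,
) -> ScheduledNode:
    # Start where the stage's last scheduled node finished, or at 0 for an empty stage.
    curr_time = pipeline_schedule[stage][-1].completion_time if pipeline_schedule[stage] else 0
    return ScheduledNode(
        type=node_type,
        chunk=chunk,
        stage=stage,
        minibatch=minibatch,
        start_time=curr_time,
        completion_time=curr_time + ONE_TIME_UNIT,
    )

Call sites would then read d_queue_w.append(make_node(pipeline_schedule, stage, "W", chunk=1, minibatch=mb)) in the queue loops, or pipeline_schedule[stage].append(make_node(...)) in bwdW_step.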
self.f_mem = f_mem + self.b_mem = b_mem + self.w_mem = w_mem + self.fbw_cost = [f_cost, b_cost, w_cost] + self.fbw_mem = [f_mem, b_mem, w_mem] + self.max_mem = max_mem or f_mem * self.n_stage * 2 + + def convert_to_dualV(self, pipeline_schedule: List[List[ScheduledNode]]) -> List[List[ScheduledNode]]: + """ + convert zbv to dualV, spec convert parital B&W to Fully Backward. To save memory for caching dx + """ + dualV_schedules = [[] for _ in range(self.n_stage)] + for stage in range(self.n_stage): + for node in pipeline_schedule[stage]: + if node.type == "B": + if node.chunk == 1 and node.minibatch in range(self.n_stage - 1 - stage, self.n_micro - 1 - stage): + dualV_schedules[stage].append( + ScheduledNode( + type="Full_B", + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + ) + ) + elif node.chunk == 0 and node.minibatch in range( + self.n_micro - self.n_stage - self.n_stage + stage + ): + dualV_schedules[stage].append( + ScheduledNode( + type="Full_B", + chunk=node.chunk, + stage=node.stage, + minibatch=node.minibatch, + start_time=node.start_time, + completion_time=node.completion_time, + ) + ) + else: + dualV_schedules[stage].append(node) + elif node.type == "W": + if node.chunk == 1 and node.minibatch in range(self.n_stage - 1 - stage, self.n_micro - 1 - stage): + pass + elif node.chunk == 0 and node.minibatch in range( + self.n_micro - self.n_stage - self.n_stage + stage + ): + pass + else: + dualV_schedules[stage].append(node) + else: + dualV_schedules[stage].append(node) + + return dualV_schedules diff --git a/colossalai/pipeline/schedule/zero_bubble_pp.py b/colossalai/pipeline/schedule/zero_bubble_pp.py index edbb7118aa1a..f8bc3cde4108 100644 --- a/colossalai/pipeline/schedule/zero_bubble_pp.py +++ b/colossalai/pipeline/schedule/zero_bubble_pp.py @@ -588,6 +588,81 @@ def backward_b_step( input_obj_grad[k] = v.grad return input_obj_grad + def backward_full_b_step( + self, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + # micro_batch: Optional[dict], + input_obj: Optional[dict], + output_obj: Union[dict, torch.Tensor], + output_obj_grad: Optional[dict], + ) -> Optional[dict]: + """Fully Backward step of the pipeline; we calculate "dx = w*dy & dw = x*dy" here; + + Args: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + optimizer (OptimizerWrapper): Optimizer to update the model + input_obj (Optional[Tuple(dict)]): x. (microbatch, input_obj) + output_obj (Union[dict, torch.Tensor]): y. + output_obj_grad (dict): dy. + + Returns: + Optional[dict]: dx. + """ + # calculate fully step ; include dx = w*dy & dw = x*dy; + + # Retain the grad on the input_obj. No need retain_grad microbatch + if input_obj is not None: + tree_map(retain_grad, input_obj) + + # x, y, dy list for backward_by_grad; Type: list[tensor]; + input_obj_ = [] + output_obj_ = [] + output_obj_grad_ = [] + + # For chunk 0 stage 0, use micro_batch as input_obj_; and we don't have to cal microbatch dx. 
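convert_to_dualV decides which backward nodes to fuse purely from index windows: for a given stage, chunk-1 nodes with minibatch in range(n_stage - 1 - stage, n_micro - 1 - stage) and chunk-0 nodes with minibatch in range(n_micro - 2 * n_stage + stage) have their B converted to Full_B and their W dropped, since the fused step computes dW inline. The helper below just prints those windows; full_b_windows is an illustrative name, not part of the patch.

def full_b_windows(n_stage: int, n_micro: int) -> dict:
    # Mirror the membership tests in convert_to_dualV.
    windows = {}
    for stage in range(n_stage):
        windows[stage] = {
            "chunk1": list(range(n_stage - 1 - stage, n_micro - 1 - stage)),
            "chunk0": list(range(n_micro - 2 * n_stage + stage)),
        }
    return windows


# The configuration used by test_dualpipeV_schedule below: n_stage=4, n_micro=10.
for stage, window in full_b_windows(n_stage=4, n_micro=10).items():
    print(stage, window)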
+ + # For loss backward; output_obj is loss; output_obj_grad should be None + if model_chunk_id == 1 and self.stage_manager.is_first_stage(ignore_chunk=True): + assert output_obj_grad is None + input_obj_, _ = tree_flatten(input_obj) + output_obj_.append(output_obj) # LOSS + output_obj_grad_.append(output_obj_grad) # None + + # For other chunk stage, use input_obj as input_obj_; + else: + input_obj_, _ = tree_flatten(input_obj) + output_obj_, _ = tree_flatten(output_obj) # y + output_obj_grad_, _ = tree_flatten(output_obj_grad) # dy + + # filter item which is not torch.Tensor + input_obj_ = [v for v in input_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_ = [v for v in output_obj_ if isinstance(v, torch.Tensor) or v is None] + output_obj_grad_ = [v for v in output_obj_grad_ if isinstance(v, torch.Tensor) or v is None] + + try: + ctx = optimizer.no_sync() + except AttributeError: + ctx = model_chunk.no_sync() + with ctx: + optimizer.backward_by_grad( + tensor=output_obj_, + grad=output_obj_grad_, + # inputs=input_obj_, + retain_graph=False, + ) + # Format output_obj_grad + input_obj_grad = dict() + if model_chunk_id == 0 and self.stage_manager.is_first_stage(ignore_chunk=True): + pass + else: + for k, v in input_obj.items(): + if isinstance(v, torch.Tensor) and v.grad is not None: + input_obj_grad[k] = v.grad + return input_obj_grad + def backward_w_step( self, model_chunk: Union[ModuleList, Module], @@ -806,6 +881,74 @@ def schedule_b( self.send_backward_buffer[model_chunk_id].append(input_object_grad) WeightGradStore.flush(chunk=model_chunk_id) + def schedule_full_b( + self, + scheduled_node, + model_chunk: Union[ModuleList, Module], + model_chunk_id: int, + optimizer: OptimizerWrapper, + ): + """A Fully backward schedule; Include recv bwd --> cal fully bwd step --> send bwd; + + Args: + scheduled_node: + model_chunk (ModuleList or Module): Model Chunk to be run; + model_chunk_id (int): The current model chunk idx; + Returns: + Nothing. 
+ """ + # Step1: recv bwd + if model_chunk_id == 0: + # chunk0 is last stage; recv output_grad from local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + output_tensor_grad = self.local_send_backward_buffer.pop(0) + # chunk0 not last stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] + else: + # chunk1, is first stage; recv LOSS from local send bwd buffer + if self.stage_manager.is_first_stage(ignore_chunk=True): + output_tensor_grad = None + # chunk1, not first stage; recv output_grad from recv_backward_buffer + else: + output_tensor_grad = self.recv_backward_buffer[model_chunk_id].pop(0) + for h in output_tensor_grad[1]: + h.wait() + output_tensor_grad = output_tensor_grad[0] + + # get input and output object from buffer; + input_obj = self.input_tensors[model_chunk_id].pop(0) + output_obj = self.output_tensors[model_chunk_id].pop(0) + + input_object_grad = self.backward_full_b_step( + model_chunk=model_chunk, + model_chunk_id=model_chunk_id, + optimizer=optimizer, + input_obj=input_obj, + output_obj=output_obj, + output_obj_grad=output_tensor_grad, + ) + + # Step3: send bwd + if model_chunk_id == 0: + # do nothing; end of bwd; + if self.stage_manager.is_first_stage(ignore_chunk=True): + pass + # save input_object_grad to send_backward_buffer + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + else: + # send to local_send_backward_buffer + if self.stage_manager.is_last_stage(ignore_chunk=True): + self.local_send_backward_buffer.append(input_object_grad) + # send to next + else: + self.send_backward_buffer[model_chunk_id].append(input_object_grad) + # WeightGradStore.flush(chunk=model_chunk_id) + def schedule_w( self, scheduled_node, @@ -919,6 +1062,15 @@ def run_forward_backward( model_chunk_id=scheduled_node.chunk, optimizer=optimizer, ) + elif scheduled_node.type == "Full_B": + WeightGradStore.enabled = False + self.schedule_full_b( + scheduled_node=scheduled_node, + model_chunk=model_chunk, + model_chunk_id=scheduled_node.chunk, + optimizer=optimizer, + ) + WeightGradStore.enabled = True elif scheduled_node.type == "W": self.schedule_w( scheduled_node=scheduled_node, diff --git a/colossalai/pipeline/weight_grad_store.py b/colossalai/pipeline/weight_grad_store.py index 1a9ef142156d..c232aeca0803 100644 --- a/colossalai/pipeline/weight_grad_store.py +++ b/colossalai/pipeline/weight_grad_store.py @@ -2,7 +2,9 @@ class WeightGradStore: - + enabled: bool = ( + True # if True: cache W in Layer, and pop to cal W; else: do not cache W, perform a full Bwd in pipeline; + ) cache = [] weight_grad_queue = [queue.Queue(), queue.Queue()] diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py index 0252f90e1c27..1bc61b07fb9d 100644 --- a/colossalai/shardformer/layer/_operation.py +++ b/colossalai/shardformer/layer/_operation.py @@ -126,7 +126,7 @@ def backward(ctx, grad_output): # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -146,7 +146,7 @@ def backward(ctx, grad_output): else: grad_weight = total_input.t().matmul(grad_output) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -209,7 +209,7 @@ def backward(ctx, 
grad_output): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -229,7 +229,7 @@ def backward(ctx, grad_output): else: grad_weight = total_input.t().matmul(grad_output) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -297,7 +297,7 @@ def backward(ctx, grad_output): # all-reduce scheduled first and have GPU resources allocated, CUDA_DEVICE_MAX_CONNECTIONS=1 is set in shardformer.py if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -317,7 +317,7 @@ def backward(ctx, grad_output): else: grad_weight = grad_output.t().matmul(total_input) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -376,7 +376,7 @@ def backward(ctx, grad_output): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -396,7 +396,7 @@ def backward(ctx, grad_output): else: grad_weight = grad_output.t().matmul(total_input) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -601,7 +601,7 @@ def backward(ctx, grad_output): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -621,7 +621,7 @@ def backward(ctx, grad_output): else: grad_weight = grad_output.t().matmul(total_input) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -776,7 +776,7 @@ def backward(ctx, grad_output): if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -796,7 +796,7 @@ def backward(ctx, grad_output): else: grad_weight = grad_output.t().matmul(total_input) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -937,7 +937,7 @@ def backward(ctx, grad_output): # split dx & dw if _grad_accum_fusion_available and weight.grad is not None: grad = weight.grad - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, @@ -957,7 +957,7 @@ def backward(ctx, grad_output): else: grad_weight = total_input.t().matmul(grad_output) else: - if use_zbv: + if use_zbv and WeightGradStore.enabled: WeightGradStore.put( total_input, grad_output, diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py index 2964f83f4f86..ebbe59e15949 100644 --- a/examples/language/llama/benchmark.py +++ b/examples/language/llama/benchmark.py @@ -21,7 +21,7 @@ from colossalai.cluster import DistCoordinator from colossalai.lazy import LazyInitContext from colossalai.nn.optimizer import HybridAdam -from colossalai.pipeline.schedule.v_schedule import PipelineGraph +from colossalai.pipeline.schedule.v_schedule import DualVPipelineGraph, PipelineGraph from colossalai.shardformer import PipelineGradientCheckpointConfig warnings.filterwarnings("ignore") @@ -94,6 +94,9 @@ def main(): parser.add_argument("--custom-ckpt", action="store_true", help="Customize checkpoint", 
default=False) parser.add_argument("--pp_style", default="1f1b", choices=["1f1b", "interleaved", "zbv"]) + parser.add_argument( + "--bwd_style", default="full_b", choices=["full_b", "bw", "mix_b"] + ) # full_b: all layer perform fully bwd; bw: all layer perform b&w; mix_b: combine with fully bwd and b&w; parser.add_argument("--n_chunks", default=1, help="number of model chunks", type=eval) parser.add_argument("--profile", action="store_true", help="Profile the code") parser.add_argument( @@ -222,20 +225,35 @@ def empty_init(): mem_f = 34 * config.hidden_size + 5 * config.num_attention_heads * args.max_length mem_w = -32 * config.hidden_size mem_b = -mem_w - mem_f - scheduler_nodes = PipelineGraph( - n_stage=args.pp, - n_micro=args.batch_size // args.mbs, - f_cost=1000, - b_cost=1000, - w_cost=1000, - c_cost=1, - f_mem=mem_f * 1.5, - b_mem=mem_b * 1.5, - w_mem=mem_w * 1.5, - ).get_v_schedule() + if args.bwd_style == "mix_b": + scheduler_graph = DualVPipelineGraph( + n_stage=args.pp, + n_micro=args.batch_size // args.mbs, + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f * 1.5, + b_mem=mem_b * 1.5, + w_mem=mem_w * 1.5, + ) + scheduler_nodes = scheduler_graph.get_v_schedule() + scheduler_nodes = scheduler_graph.convert_to_dualV(scheduler_nodes) + else: + scheduler_nodes = PipelineGraph( + n_stage=args.pp, + n_micro=args.batch_size // args.mbs, + f_cost=1000, + b_cost=1000, + w_cost=1000, + c_cost=1, + f_mem=mem_f * 1.5, + b_mem=mem_b * 1.5, + w_mem=mem_w * 1.5, + ).get_v_schedule() else: scheduler_nodes = None - + print(f"scheduler_nodes {scheduler_nodes}") plugin = HybridParallelPlugin( tp_size=args.tp, pp_size=args.pp, diff --git a/tests/test_pipeline/test_schedule/test_dualpipe_schedule.py b/tests/test_pipeline/test_schedule/test_dualpipe_schedule.py new file mode 100644 index 000000000000..3325fc1b35a0 --- /dev/null +++ b/tests/test_pipeline/test_schedule/test_dualpipe_schedule.py @@ -0,0 +1,113 @@ +from typing import List + +from colossalai.pipeline.schedule.dualpipe_schedule import DUALPIPE_NODETYPE, DualPipeGraph +from colossalai.pipeline.schedule.v_schedule import DualVPipelineGraph, PipelineGraph, ScheduledNode +from colossalai.testing import parameterize + + +def print_pipeline_details( + pipeline_schedule: List[List[ScheduledNode]], + chunk_mode: bool = False, + mbs_mode: bool = False, + empty_bubble_str_mode: bool = False, +): + assert not ( + chunk_mode and mbs_mode + ), "Only one mode is supported at the same time, please choose from chunk_mode and mbs_mode" + schedule_str = "" + for stage in range(len(pipeline_schedule)): + stage_nodes = [] + for node in pipeline_schedule[stage]: + if node.type in DUALPIPE_NODETYPE: + if node.type == "EMPTY_BUBBLE": + if empty_bubble_str_mode: + stage_nodes.append("E") + else: + stage_nodes.append(" ") + else: + if chunk_mode: + stage_nodes.append(node.type + str(node.chunk)) + elif mbs_mode: + stage_nodes.append(node.type + str(node.minibatch)) + else: + stage_nodes.append(node.type) + stage_str = "".join([_ for _ in stage_nodes]) + schedule_str += "\n" + stage_str + print(schedule_str) + + +@parameterize( + "test_config", + [ + { + "n_stage": 16, + }, + ], +) +def test_dualpipe_schedule(test_config): + dualpipe = DualPipeGraph( + n_stage=test_config["n_stage"], + n_micro=(test_config["n_stage"] + 2) * 2, + ) + dualpipe_schedule = dualpipe.get_dualpipe_schedule() + # print(dualpipe_schedule) + dualpipe.print_details( + dualpipe_schedule, + chunk_mode=True, + # mbs_mode=True, + empty_bubble_str_mode=True, + ) + + 
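The --bwd_style options added to the benchmark map onto the WeightGradStore.enabled flag introduced in weight_grad_store.py: with the split B&W style the linear backward stashes the weight-gradient work for a later W node, while around a Full_B node the run_forward_backward dispatcher clears the flag so dx and dW are computed in one pass. The toy classes below are simplified stand-ins for that idea, not the real WeightGradStore or shardformer API.

import torch


class ToyWeightGradStore:
    # Simplified stand-in: while `enabled` is True, dW work is queued for a later "W" node;
    # while it is False (around Full_B nodes), dW is computed on the spot.
    enabled: bool = True
    cache: list = []

    @classmethod
    def put(cls, fn) -> None:
        cls.cache.append(fn)

    @classmethod
    def pop_all(cls) -> None:
        while cls.cache:
            cls.cache.pop(0)()


def linear_backward(x: torch.Tensor, grad_out: torch.Tensor, weight: torch.Tensor):
    # Convention for this sketch: y = x @ weight, weight has shape (in_features, out_features).
    grad_input = grad_out @ weight.t()  # dx = dy W^T, always computed immediately (the "B" part)

    def compute_dw() -> torch.Tensor:
        return x.t() @ grad_out  # dW = x^T dy (the "W" part)

    if ToyWeightGradStore.enabled:
        ToyWeightGradStore.put(compute_dw)  # bw / zbv style: defer dW
        return grad_input, None
    return grad_input, compute_dw()  # full_b style: one fused backward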
+@parameterize( + "test_config", + [ + { + "n_stage": 4, + }, + ], +) +def test_dualpipeV_schedule(test_config): + mem_f = 34 * 4096 + 5 * 24 * 4096 + mem_w = -32 * 4096 + mem_b = -mem_w - mem_f + # zbv + zbv_schedule = PipelineGraph( + n_stage=test_config["n_stage"], + n_micro=10, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f * 1.5, + b_mem=mem_b * 1.5, + w_mem=mem_w * 1.5, + ).get_v_schedule() + print_pipeline_details( + zbv_schedule, + mbs_mode=True, + ) + + # dual V + dualV_graph = DualVPipelineGraph( + n_stage=test_config["n_stage"], + n_micro=10, + f_cost=1, + b_cost=1, + w_cost=1, + c_cost=1, + f_mem=mem_f * 1.5, + b_mem=mem_b * 1.5, + w_mem=mem_w * 1.5, + ) + dualV_schedule = dualV_graph.get_v_schedule() + dualV_schedule = dualV_graph.convert_to_dualV(dualV_schedule) + print_pipeline_details( + dualV_schedule, + mbs_mode=True, + ) + + +if __name__ == "__main__": + # test_dualpipe_schedule() + test_dualpipeV_schedule()
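Beyond the printed timelines, a quick sanity check on convert_to_dualV is to compare per-stage node-type counts before and after conversion: inside the selected windows a B node and its matching W node are replaced by a single Full_B, and everything else passes through unchanged. A minimal sketch, reusing the configuration from test_dualpipeV_schedule above:

from collections import Counter

from colossalai.pipeline.schedule.v_schedule import DualVPipelineGraph

mem_f = 34 * 4096 + 5 * 24 * 4096
mem_w = -32 * 4096
mem_b = -mem_w - mem_f

graph = DualVPipelineGraph(
    n_stage=4,
    n_micro=10,
    f_cost=1,
    b_cost=1,
    w_cost=1,
    c_cost=1,
    f_mem=mem_f * 1.5,
    b_mem=mem_b * 1.5,
    w_mem=mem_w * 1.5,
)
zbv_schedule = graph.get_v_schedule()
dualv_schedule = graph.convert_to_dualV(zbv_schedule)

# Per-stage histogram of node types before and after the conversion.
for stage, (before, after) in enumerate(zip(zbv_schedule, dualv_schedule)):
    print(stage, Counter(n.type for n in before), Counter(n.type for n in after))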