12 changes: 3 additions & 9 deletions colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -19,7 +19,6 @@
HybridParallelPlugin,
HybridParallelZeroOptimizer,
get_param_info,
reinitialize_optimizer,
)
from colossalai.checkpoint_io import MoECheckpointIO
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
@@ -468,18 +467,13 @@ def configure(
use_fp8=self.use_fp8,
)
if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
if self.ep_size > 1:
# if ep is enabled, the num of (moe) paramaters changed since they are sharded among ep groups
# but the optimizer is not aware of ep, so we need to update the optimizer
reinitialize_optimizer(optimizer, model)

if self.zero_stage == 0:
is_zero = False
if self.precision in ["fp16", "bf16"]:
optimizer = HybridParallelAMPOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
precision=self.precision,
max_norm=self.max_norm,
@@ -489,7 +483,7 @@ def configure(
optimizer = HybridParallelNaiveOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
max_norm=self.max_norm,
pp_process_group=self.pp_group,
@@ -507,7 +501,7 @@ def configure(
optimizer = MoeHybridParallelZeroOptimizer(
optimizer,
model,
use_pipeline=self.enable_pipeline_parallelism,
use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1,
param_info=param_info,
dp_process_group=self.mixed_dp_group,
tp_process_group=self.tp_group,
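Note on this hunk: with expert parallelism enabled, the experts a rank holds are only a shard of the full set, so the parameters the optimizer should manage change after sharding. Instead of patching the optimizer up front with `reinitialize_optimizer`, the plugin now passes `use_pipeline=self.enable_pipeline_parallelism or self.ep_size > 1`, reusing the path where the optimizer wrappers rebuild their parameter bookkeeping from the wrapped (sharded) model. As a rough, hypothetical sketch of what such a rebuild amounts to (plain PyTorch, not the ColossalAI helper):

```python
import torch

def rebuild_param_groups(optimizer: torch.optim.Optimizer, model: torch.nn.Module) -> None:
    # Hypothetical helper: point an existing optimizer at the parameters the model
    # actually holds after sharding, and drop state keyed by the replaced tensors.
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer.param_groups = [{**optimizer.defaults, "params": params}]
    optimizer.state.clear()
```

After such a rebuild, `step()` only touches the sharded copies; any state attached to parameters that were freed by sharding is discarded.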
5 changes: 4 additions & 1 deletion colossalai/cluster/process_group_mesh.py
@@ -64,7 +64,10 @@ def destroy_mesh_process_groups(self):
system resources.
"""
for group in self._ranks_to_group.values():
dist.destroy_process_group(group)
try:
dist.destroy_process_group(group)
except ValueError:
pass

# Manually clear all process groups to save memory
gc.collect()
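Note on this hunk: recent PyTorch releases raise `ValueError` from `destroy_process_group` when the group is no longer registered (for example when a global teardown already destroyed it), so the mesh cleanup now treats that case as a no-op. The same guard as a standalone sketch:

```python
import torch.distributed as dist

def destroy_group_quietly(group: dist.ProcessGroup) -> None:
    # Best-effort teardown: a group that was already destroyed elsewhere is not an error here.
    try:
        dist.destroy_process_group(group)
    except ValueError:
        pass
```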
8 changes: 5 additions & 3 deletions colossalai/lazy/lazy_init.py
@@ -104,7 +104,7 @@ def _data_tolist(tensor: torch.Tensor) -> list:
return tensor.data.tolist()


def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
def _convert_cls(tensor: "LazyTensor", target: torch.Tensor, requires_grad=None) -> torch.Tensor:
"""Convert a lazy tensor's class to target's class, with target's data.

The reason why we change the class of a lazy tensor in-place is that this can easily handle shared modules/parameters, which is common in huggingface models.
@@ -117,13 +117,14 @@ def _convert_cls(tensor: "LazyTensor", target: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: the converted tensor
"""
requires_grad = target.requires_grad if requires_grad is None else requires_grad
cls_to_become = Parameter if isinstance(tensor, Parameter) else torch.Tensor
tensor.__class__ = cls_to_become
if cls_to_become is Parameter:
# to fit UninitializedParameter
delattr(tensor, "_is_param")
tensor.data = target
tensor.requires_grad = target.requires_grad
tensor.requires_grad = requires_grad
# subclass of torch.Tensor does not have tolist() method
# overwrite this method after materialization or distribution
tensor.tolist = MethodType(_data_tolist, tensor)
@@ -212,9 +213,10 @@ def materialize(self) -> torch.Tensor:
Returns:
torch.Tensor: The materialized tensor (self).
"""
requires_grad = self.requires_grad
target = self._materialize_data()
self.clean()
return _convert_cls(self, target)
return _convert_cls(self, target, requires_grad=requires_grad)

def clean(self) -> None:
"""Clean all stored operations, meta data and materialized data, which prevents memory leaking. This should be called after all tensors are materialized."""
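Note on this hunk: `materialize()` now records the lazy tensor's own `requires_grad` before materializing and passes it to `_convert_cls`, so a flag toggled on the lazy tensor (for example, freezing a parameter before materialization) survives instead of being overwritten by the materialized target's flag. A stripped-down analogue of the capture-then-restore pattern, independent of `LazyTensor`:

```python
import torch

def adopt_data(holder: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Capture the holder's flag first; copying target.requires_grad instead (the old
    # behaviour) could silently enable or disable gradients on the holder, since
    # freshly materialized tensors often carry requires_grad=False.
    requires_grad = holder.requires_grad
    holder.data = target
    holder.requires_grad = requires_grad
    return holder
```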
277 changes: 277 additions & 0 deletions colossalai/shardformer/modeling/deepseek_v3.py
@@ -0,0 +1,277 @@
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.distributed as dist
from torch.distributed import ProcessGroup
from transformers.cache_utils import Cache, DynamicCache
from transformers.modeling_attn_mask_utils import _prepare_4d_causal_attention_mask
from transformers.modeling_outputs import BaseModelOutputWithPast

from colossalai.lazy import LazyInitContext
from colossalai.moe._operation import (
DPGradScalerIn,
DPGradScalerOut,
EPGradScalerIn,
EPGradScalerOut,
all_to_all_uneven,
)
from colossalai.shardformer.layer.linear import ParallelModule
from colossalai.shardformer.shard.utils import set_tensors_to_none
from colossalai.tensor.moe_tensor.api import set_moe_tensor_ep_group


class EpDeepseekV3MoE(ParallelModule):
"""
A mixed expert module containing shared experts.
"""

def __init__(self, config):
raise RuntimeError(f"Please use `from_native_module` to create an instance of {self.__class__.__name__}")

def setup_process_groups(
self,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
):
assert moe_dp_group is not None
assert ep_group is not None

self.ep_size = dist.get_world_size(ep_group)
self.ep_rank = dist.get_rank(ep_group)
self.num_experts = self.config.n_routed_experts
assert self.num_experts % self.ep_size == 0

self.ep_group = ep_group
self.num_experts_per_ep = self.num_experts // self.ep_size
self.experts_per_rank = self.num_experts_per_ep
self.expert_start_idx = self.ep_rank * self.num_experts_per_ep
held_experts = self.experts[self.expert_start_idx : self.expert_start_idx + self.num_experts_per_ep]

set_tensors_to_none(self.experts, exclude=set(held_experts))

# setup moe_dp group
self.moe_dp_group = moe_dp_group
self.moe_dp_size = dist.get_world_size(moe_dp_group)

for p in self.experts.parameters():
set_moe_tensor_ep_group(p, ep_group)

@staticmethod
def from_native_module(
module,
moe_dp_group: ProcessGroup,
ep_group: ProcessGroup,
*args,
**kwargs,
) -> "EpDeepseekV3MoE":
if module.__class__.__name__ != "DeepseekV3MLP":
module.__class__ = EpDeepseekV3MoE
module.setup_process_groups(moe_dp_group, ep_group)
LazyInitContext.materialize(module)
return module

def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
identity = hidden_states
orig_shape = hidden_states.shape
topk_idx, topk_weight = self.gate(hidden_states)
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
y = self.moe_forward(hidden_states, topk_idx, topk_weight).view(*orig_shape)
if self.config.n_shared_experts is not None:
y = y + self.shared_experts(identity)
return y

def moe_forward(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
cnts = topk_ids.new_zeros((topk_ids.shape[0], len(self.experts)))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)
idxs = topk_ids.view(-1).argsort()
sorted_tokens = x[idxs // topk_ids.shape[1]]
if self.ep_size > 1:
tokens_per_ep_rank = tokens_per_expert.view(self.ep_size, -1).sum(dim=1)
tokens_per_expert_group = tokens_per_expert.new_empty(tokens_per_expert.shape[0])
dist.all_to_all_single(tokens_per_expert_group, tokens_per_expert, group=self.ep_group)

output_splits = tokens_per_expert_group.view(self.ep_size, -1).sum(1).tolist()
input_split_sizes = tokens_per_ep_rank.tolist()

gathered_tokens, _ = all_to_all_uneven(sorted_tokens, input_split_sizes, output_splits, self.ep_group)
tokens_per_expert_post_gather = tokens_per_expert_group.view(self.ep_size, self.experts_per_rank).sum(dim=0)
gatherd_idxs = np.zeros(shape=(gathered_tokens.shape[0],), dtype=np.int32)
s = 0
for i, k in enumerate(tokens_per_expert_group.cpu().numpy()):
gatherd_idxs[s : s + k] = i % self.experts_per_rank
s += k
gatherd_idxs = gatherd_idxs.argsort()
sorted_tokens = gathered_tokens[gatherd_idxs]
tokens_per_expert = tokens_per_expert_post_gather

# moe-dp related code
activate_experts = tokens_per_expert_post_gather > 0
activate_experts = activate_experts.int()
dist.all_reduce(activate_experts, group=self.moe_dp_group)

# ep related code
sorted_tokens = EPGradScalerIn.apply(sorted_tokens, self.ep_size)

tokens_per_expert = tokens_per_expert.cpu().numpy()

outputs = []
start_idx = 0
for i, num_tokens in enumerate(tokens_per_expert):
end_idx = start_idx + num_tokens
if num_tokens == 0:
continue
expert = self.experts[i + self.ep_rank * self.experts_per_rank]
tokens_for_this_expert = sorted_tokens[start_idx:end_idx]
# moe-dp related code
tokens_for_this_expert = DPGradScalerIn.apply(tokens_for_this_expert, self.moe_dp_size, activate_experts[i])
expert_out = expert(tokens_for_this_expert)
# moe-dp related code
expert_out = DPGradScalerOut.apply(expert_out, self.moe_dp_size, activate_experts[i])
outputs.append(expert_out)
start_idx = end_idx

if len(outputs) > 0:
outs = torch.cat(outputs, dim=0)
else:
assert sorted_tokens.numel() == 0, f"sorted_tokens: should be empty, but got {sorted_tokens.shape}"
outs = sorted_tokens

if self.ep_size > 1:
outs = EPGradScalerOut.apply(outs, self.ep_size)
new_x = torch.empty_like(outs)
new_x[gatherd_idxs] = outs
gathered_tokens, _ = all_to_all_uneven(new_x, output_splits, input_split_sizes, self.ep_group)
outs = gathered_tokens

new_x = torch.empty_like(outs)
new_x[idxs] = outs
final_out = (
(new_x.view(*topk_ids.shape, -1).type(topk_weight.dtype) * topk_weight.unsqueeze(dim=-1))
.sum(dim=1)
.type(new_x.dtype)
)

return final_out
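Aside (reviewer note, not part of the file above): the dispatch bookkeeping at the top of `moe_forward` can be exercised on a single device. Each routed (token, expert) pair contributes one row; sorting the flattened expert ids groups those rows so every expert reads one contiguous slice. Under expert parallelism the per-expert counts are then exchanged with `all_to_all_single` and the rows themselves with `all_to_all_uneven`, but the local bucketing is the same. A minimal sketch with made-up shapes:

```python
import torch

num_tokens, hidden, num_experts, top_k = 8, 16, 4, 2
x = torch.randn(num_tokens, hidden)
router_logits = torch.randn(num_tokens, num_experts)
topk_weight, topk_ids = router_logits.topk(top_k, dim=-1)  # distinct experts per token

cnts = topk_ids.new_zeros((num_tokens, num_experts))
cnts.scatter_(1, topk_ids, 1)
tokens_per_expert = cnts.sum(dim=0)      # routed copies destined for each expert
idxs = topk_ids.view(-1).argsort()       # flat (token, slot) indices grouped by expert id
sorted_tokens = x[idxs // top_k]         # one row per routed copy, expert-contiguous

start = 0
for expert_id, n in enumerate(tokens_per_expert.tolist()):
    expert_input = sorted_tokens[start : start + n]  # contiguous slice for this expert
    start += n
```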


def deepseek_v3_model_forward(
self,
input_ids: torch.LongTensor = None,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
inputs_embeds: Optional[torch.FloatTensor] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple, BaseModelOutputWithPast]:
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
output_hidden_states = (
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

# retrieve input_ids and inputs_embeds
if input_ids is not None and inputs_embeds is not None:
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
elif input_ids is not None:
batch_size, seq_length = input_ids.shape[:2]
elif inputs_embeds is not None:
batch_size, seq_length = inputs_embeds.shape[:2]
else:
raise ValueError("You have to specify either input_ids or inputs_embeds")

past_key_values_length = 0
if use_cache:
use_legacy_cache = not isinstance(past_key_values, Cache)
if use_legacy_cache:
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_key_values_length = past_key_values.get_usable_length(seq_length)

if position_ids is None:
device = input_ids.device if input_ids is not None else inputs_embeds.device
position_ids = torch.arange(
past_key_values_length,
seq_length + past_key_values_length,
dtype=torch.long,
device=device,
)
position_ids = position_ids.unsqueeze(0)

if inputs_embeds is None:
inputs_embeds = self.embed_tokens(input_ids)

if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
else:
# 4d mask is passed through the layers
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
inputs_embeds,
past_key_values_length,
)

# embed positions
hidden_states = inputs_embeds

# decoder layers
all_hidden_states = () if output_hidden_states else None
all_self_attns = () if output_attentions else None
next_decoder_cache = None

for i, decoder_layer in enumerate(self.layers):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if self.gradient_checkpointing and i > 0:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
position_ids,
past_key_values,
output_attentions,
use_cache,
)
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
)

hidden_states = layer_outputs[0]

if use_cache:
next_decoder_cache = layer_outputs[2 if output_attentions else 1]

if output_attentions:
all_self_attns += (layer_outputs[1],)

hidden_states = self.norm(hidden_states)

# add hidden states from the last decoder layer
if output_hidden_states:
all_hidden_states += (hidden_states,)

next_cache = None
if use_cache:
next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
if not return_dict:
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
return BaseModelOutputWithPast(
last_hidden_state=hidden_states,
past_key_values=next_cache,
hidden_states=all_hidden_states,
attentions=all_self_attns,
)
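Aside on the cache handling in `deepseek_v3_model_forward`: a legacy tuple-style `past_key_values` is wrapped in a `DynamicCache`, queried for how many positions are already cached, and converted back to the tuple format before returning. A small round trip under the transformers >= 4.36 cache API (shapes are illustrative):

```python
import torch
from transformers.cache_utils import DynamicCache

# one decoder layer of cached keys/values: (batch=1, heads=2, cached_len=3, head_dim=4)
legacy = ((torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4)),)

cache = DynamicCache.from_legacy_cache(legacy)
past_len = cache.get_usable_length(5)    # 3 positions already cached for a new length-5 input
round_tripped = cache.to_legacy_cache()  # back to the tuple format for legacy callers
```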
7 changes: 7 additions & 0 deletions colossalai/shardformer/policies/auto_policy.py
@@ -167,6 +167,13 @@ class PolicyLocation:
"transformers_modules.modeling_deepseek.DeepseekForCausalLM": PolicyLocation(
file_name="deepseek", class_name="DeepseekForCausalLMPolicy"
),
# DeepseekV3
"transformers_modules.modeling_deepseek.DeepseekV3Model": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ModelPolicy"
),
"transformers_modules.modeling_deepseek.DeepseekV3ForCausalLM": PolicyLocation(
file_name="deepseek_v3", class_name="DeepseekV3ForCausalLMPolicy"
),
# Falcon
"transformers.models.falcon.modeling_falcon.FalconModel": PolicyLocation(
file_name="falcon", class_name="FalconModelPolicy"