diff --git a/applications/ColossalChat/coati/dataset/loader.py b/applications/ColossalChat/coati/dataset/loader.py
index 16fd385bad30..bd91d1b8d5c8 100755
--- a/applications/ColossalChat/coati/dataset/loader.py
+++ b/applications/ColossalChat/coati/dataset/loader.py
@@ -396,6 +396,7 @@ def apply_chat_template_and_mask(
tokens = tokens[:max_length]
assistant_mask = assistant_mask[:max_length]
attention_mask = attention_mask[:max_length]
+
input_ids = torch.tensor(tokens, dtype=torch.long)
attention_mask = torch.tensor(attention_mask, dtype=torch.long)
labels = input_ids.clone()
diff --git a/applications/ColossalChat/coati/distributed/grpo_consumer.py b/applications/ColossalChat/coati/distributed/grpo_consumer.py
index 754f780979a9..2a3c2d8051a0 100644
--- a/applications/ColossalChat/coati/distributed/grpo_consumer.py
+++ b/applications/ColossalChat/coati/distributed/grpo_consumer.py
@@ -132,6 +132,8 @@ def __init__(
eta_min=0.1 * grpo_config.get("lr", 1e-6),
)
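+        # advantage estimator name set by rl_example.py: "GRPO", "DAPO", "REINFORCE_PPB", or "RLOO"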
+ self.adv = grpo_config.get("algo")
+
def setup(self):
super().setup()
if (not self.plugin.pp_size > 1 and self.rank == 0) or (
@@ -180,23 +182,72 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
response_length = torch.sum(action_mask, dim=1).to(torch.float32)
train_microbatch_size = self.grpo_config.get("train_microbatch_size", data["input_ids"].size(0))
-        reward = data["reward"].view((-1))
-        format_acc = data["format_acc"].view((-1))
-        ans_acc = data["ans_acc"].view((-1))
-        # [minibatch_size, num_generations]
-        group_reward = reward.view(-1, self.num_generations)
-        reward_mean = group_reward.mean(dim=1)
-        # [minibatch_size x num_generations]
-        reward_mean = reward_mean.repeat_interleave(self.num_generations, dim=0)
-        reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
-        # [minibatch_size x num_generations]
-        advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
-        # [minibatch_size x num_of_generation]
-        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
+        reward = data["reward"].view((-1))
+        format_acc = data["format_acc"].view((-1))
+        ans_acc = data["ans_acc"].view((-1))
+
+        # [minibatch_size, num_generations]
+        group_reward = reward.view(-1, self.num_generations)
+        # [minibatch_size x num_generations]
+        reward_mean = group_reward.mean(dim=1).repeat_interleave(self.num_generations, dim=0)
+
+        if self.adv == "GRPO" or self.adv == "DAPO":
+            # group-normalized advantage: (reward - group mean) / (group std + eps)
+            reward_std = group_reward.std(dim=1).repeat_interleave(self.num_generations, dim=0)
+            advantages = ((reward - reward_mean) / (reward_std + 1e-4)).unsqueeze(dim=-1)
+        elif self.adv == "REINFORCE_PPB":
+            # group-mean baseline without per-group std scaling, then whiten advantages across the batch
+            advantages = (reward - reward_mean).unsqueeze(dim=-1)
+            advantages = (advantages - advantages.mean(dim=0)) / (advantages.std(dim=0) + 1e-4)
+        elif self.adv == "RLOO":
+            # leave-one-out baseline: reward - mean(other samples in the group),
+            # which simplifies to (reward - group mean) * n / (n - 1)
+            advantages = (
+                (reward - reward_mean) * self.num_generations / (self.num_generations - 1)
+            ).unsqueeze(dim=-1)
+
+        # [minibatch_size x num_generations]
+        loss_mask = torch.ones(action_mask.size(0), device=action_mask.device).bool()
# filter out overlength samples
if self.filter_truncated_response and action_mask.size(1) == self.max_length:
diff --git a/applications/ColossalChat/coati/distributed/launch.py b/applications/ColossalChat/coati/distributed/launch.py
index a48246c87526..e24a8b99e6ff 100644
--- a/applications/ColossalChat/coati/distributed/launch.py
+++ b/applications/ColossalChat/coati/distributed/launch.py
@@ -9,7 +9,13 @@
from .grpo_consumer import GRPOConsumer
from .producer import SimpleProducer
-ALGO_MAP = {"Simple": SimpleConsumer, "GRPO": GRPOConsumer, "DAPO": GRPOConsumer}
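+# All policy-gradient variants share GRPOConsumer; they differ only in how advantages
+# are computed in GRPOConsumer.step().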
+ALGO_MAP = {
+ "Simple": SimpleConsumer,
+ "GRPO": GRPOConsumer,
+ "DAPO": GRPOConsumer,
+ "REINFORCE_PPB": GRPOConsumer,
+ "RLOO": GRPOConsumer,
+}
def get_jsonl_size_fast(path: str) -> int:
diff --git a/applications/ColossalChat/coati/distributed/producer.py b/applications/ColossalChat/coati/distributed/producer.py
index 11fb5d3aa3a9..856f27b08e13 100644
--- a/applications/ColossalChat/coati/distributed/producer.py
+++ b/applications/ColossalChat/coati/distributed/producer.py
@@ -115,6 +115,9 @@ def __init__(
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_config)
self.tokenizer.padding_side = "left"
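+        # some models (e.g. Llama) ship without a pad token; fall back to EOS so left padding works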
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
# init dataloader
train_dataset_path = train_dataset_config.pop("path")
self.train_dataset = RawConversationDataset(self.tokenizer, train_dataset_path, **train_dataset_config)
diff --git a/applications/ColossalChat/rl_example.py b/applications/ColossalChat/rl_example.py
index 08814f9f1e61..e5d41cccae11 100644
--- a/applications/ColossalChat/rl_example.py
+++ b/applications/ColossalChat/rl_example.py
@@ -130,7 +130,7 @@
)
# GRPO parameters
- parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO"])
+ parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["DAPO", "GRPO", "REINFORCE_PPB", "RLOO"])
parser.add_argument("-lr", "--learning-rate", type=float, default=1e-6, help="Learning rate for GRPO.")
parser.add_argument("-kl", "--kl-coeff", type=float, default=0.01, help="KL penalty coefficient for GRPO.")
parser.add_argument(
@@ -227,13 +227,13 @@
os.environ["TOKENIZERS_PARALLELISM"] = "false" # Disable tokenizers parallelism to avoid deadlock
inference_model_config = dict(path=args.model)
- train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
+ train_model_config = dict(path=args.model, use_flash_attention_2=False, use_cache=False)
generate_config = dict(top_k=args.top_k, top_p=args.top_p, temperature=args.temperature)
if args.backend == "transformers":
inference_model_config.update(
dict(
- use_flash_attention_2=True,
+ use_flash_attention_2=False,
torch_dtype=torch.bfloat16,
)
)
@@ -283,6 +283,7 @@
if args.algo == "GRPO":
# Default Settings
grpo_config = {
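+            # "algo" selects the advantage estimator branch in GRPOConsumer.step()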
+ "algo": "GRPO",
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
"beta": args.kl_coeff, # KL penalty coefficient
@@ -304,6 +305,7 @@
elif args.algo == "DAPO":
# DAPO variant settings
grpo_config = {
+ "algo": "DAPO",
"filter_range": [0.01, 0.99], # only filter out all zero batch and all one batch
"lr": args.learning_rate,
"train_microbatch_size": args.train_microbatch_size,
@@ -330,6 +332,50 @@
else None
),
}
+ elif args.algo == "REINFORCE_PPB":
+        # Default settings (same as GRPO; only the advantage estimator differs)
+ grpo_config = {
+ "algo": "REINFORCE_PPB",
+ "lr": args.learning_rate,
+ "train_microbatch_size": args.train_microbatch_size,
+ "beta": args.kl_coeff, # KL penalty coefficient
+ "loss_variation": "sample_level",
+ "reward_fn_type": args.reward_type,
+ "max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
+ }
+ elif args.algo == "RLOO":
+        # Default settings (same as GRPO; RLOO uses a leave-one-out baseline instead)
+ grpo_config = {
+ "algo": "RLOO",
+ "lr": args.learning_rate,
+ "train_microbatch_size": args.train_microbatch_size,
+ "beta": args.kl_coeff, # KL penalty coefficient
+ "loss_variation": "sample_level",
+ "reward_fn_type": args.reward_type,
+ "max_length": args.max_new_tokens + args.max_prompt_tokens,
+ "max_new_tokens": args.max_new_tokens,
+ "response_format_tags": (
+ {
+ "think_start": {"text": "", "num_occur": 1},
+ "think_end": {"text": "", "num_occur": 1},
+ "answer_start": {"text": "", "num_occur": 1},
+ "answer_end": {"text": "", "num_occur": 1},
+ }
+ if args.reward_type == "think_answer_tags"
+ else None
+ ),
+ }
else:
raise ValueError(f"Unsupported algorithm: {args.algo}")
if args.reward_type == "code":
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index d1ad846044df..2627ffcfa4f5 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -141,7 +141,9 @@ def llama_model_forward(
invert=(sp_mode != "ring_attn"),
)
else:
- attn_kwargs: torch.Tensor = self._update_causal_mask(attention_mask, hidden_states, cache_position)
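+            # newer transformers versions require past_key_values and output_attentions
+            # to be passed explicitly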
+ attn_kwargs: torch.Tensor = self._update_causal_mask(
+ attention_mask, hidden_states, cache_position, None, False
+ )
# Support SP + PP. Later stages have already received the split input.
split_input = disable_pp or stage_manager.is_first_stage()
diff --git a/colossalai/shardformer/modeling/qwen3.py b/colossalai/shardformer/modeling/qwen3.py
new file mode 100644
index 000000000000..5e8c0762c9fa
--- /dev/null
+++ b/colossalai/shardformer/modeling/qwen3.py
@@ -0,0 +1,831 @@
+# Modified from Qwen2 modeling
+import math
+from typing import List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers.modeling_attn_mask_utils import (
+ _prepare_4d_causal_attention_mask,
+ _prepare_4d_causal_attention_mask_for_sdpa,
+)
+from transformers.modeling_outputs import (
+ BaseModelOutputWithPast,
+ CausalLMOutputWithPast,
+ SequenceClassifierOutputWithPast,
+)
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3Attention,
+ Qwen3ForCausalLM,
+ Qwen3ForSequenceClassification,
+ Qwen3Model,
+ apply_rotary_pos_emb,
+ repeat_kv,
+)
+from transformers.utils import logging
+
+from colossalai.pipeline.stage_manager import PipelineStageManager
+from colossalai.shardformer.layer._operation import all_to_all_comm, split_forward_gather_backward
+from colossalai.shardformer.shard import ShardConfig
+
+from ..layer import ColoAttention, dist_cross_entropy
+from ..layer._operation import gather_sp_output
+from ..layer.utils import is_share_sp_tp
+
+
+class Qwen3PipelineForwards:
+ """
+ This class serves as a micro library for forward function substitution of Qwen3 models
+ under pipeline setting.
+ """
+
+ @staticmethod
+ def qwen3_model_forward(
+ self: Qwen3Model,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ force_sp_output_gather: bool = True,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if stage_manager.is_first_stage():
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+ hidden_states = inputs_embeds
+ else:
+ input_shape = hidden_states.shape[:-1]
+ batch_size, seq_length = input_shape
+ device = hidden_states.device
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+ if use_cache:
+ logger.warning_once("use_cache=True is not supported for pipeline models at the moment.")
+ use_cache = False
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ # Support SP + PP
+ sp_size = shard_config.sequence_parallel_size
+ sp_group = shard_config.sequence_parallel_process_group
+ sp_mode = shard_config.sequence_parallelism_mode
+ # For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
+ if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
+ seq_length *= sp_size
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions, for the first stage, hidden_states is the input embeddings,
+ # for the other stages, hidden_states is the output of the previous stage
+ if shard_config.enable_flash_attention:
+ # in this case, attention_mask is a dict rather than a tensor
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape,
+ hidden_states.dtype,
+ hidden_states.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
+ )
+ else:
+ if self.config._attn_implementation == "flash_attention_2":
+ # 2d mask is passed through the layers
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
+ elif self.config._attn_implementation == "sdpa" and not output_attentions:
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
+ # the manual implementation that requires a 4D causal mask in all cases.
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ )
+ else:
+ # 4d mask is passed through the layers
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ hidden_states,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if stage_manager.is_first_stage():
+ if shard_config.enable_sequence_parallelism:
+ if is_share_sp_tp(sp_mode):
+ hidden_states = split_forward_gather_backward(
+ hidden_states,
+ dim=1,
+ process_group=sp_group,
+ )
+ elif sp_mode == "all_to_all":
+ hidden_states = split_forward_gather_backward(
+ hidden_states,
+ dim=1,
+ process_group=sp_group,
+ grad_scale=1 / sp_size,
+ )
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ start_idx, end_idx = stage_index[0], stage_index[1]
+ num_ckpt_layers = 0
+ if self.gradient_checkpointing and self.training:
+ num_ckpt_layers = end_idx - start_idx
+ # TODO: We can replace `gradient_checkpointing_enable` fn and initialize a gradient_checkpointing (List[bool]) for each layer
+ if shard_config.gradient_checkpoint_config is not None:
+ num_ckpt_layers = shard_config.gradient_checkpoint_config.get_num_ckpt_layers(
+ stage=stage_manager.stage,
+ num_stages=stage_manager.num_stages,
+ num_layers=end_idx - start_idx,
+ model_chunk_id=(stage_manager.model_chunk_id if stage_manager.is_interleave else 0),
+ num_model_chunks=stage_manager.num_model_chunks,
+ )
+ assert num_ckpt_layers <= end_idx - start_idx
+
+ for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if idx - start_idx < num_ckpt_layers:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if use_cache:
+ next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ if stage_manager.is_last_stage():
+ hidden_states = self.norm(hidden_states)
+ if shard_config.enable_sequence_parallelism:
+ if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+ hidden_states = gather_sp_output(hidden_states, shard_config)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if stage_manager.is_last_stage():
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+        # always return dict for intermediate stages
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def qwen3_for_causal_lm_forward(
+ self: Qwen3ForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+        >>> model = Qwen3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+        >>> prompt = "Hey, are you conscious? Can you talk to me?"
+        >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+        >>> # Generate
+        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+ logger = logging.get_logger(__name__)
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = Qwen3PipelineForwards.qwen3_model_forward(
+ self.model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ force_sp_output_gather=False,
+ )
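+        # KV cache is not supported under pipeline parallelism (use_cache is forced off above)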
+ past_key_values = None
+
+ if stage_manager.is_last_stage():
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ loss = None
+ if labels is not None:
+ loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+ else:
+ hidden_states = outputs.get("hidden_states")
+ return {"hidden_states": hidden_states}
+
+ @staticmethod
+ def qwen3_for_sequence_classification_forward(
+ self: Qwen3ForSequenceClassification,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ stage_manager: Optional[PipelineStageManager] = None,
+ hidden_states: Optional[torch.FloatTensor] = None,
+ stage_index: Optional[List[int]] = None,
+ shard_config: ShardConfig = None,
+ ):
+ r"""
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
+ """
+ logger = logging.get_logger(__name__)
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # TODO(jianghai): left the recording kv-value tensors as () or None type, this feature may be added in the future.
+ if output_attentions:
+ logger.warning_once("output_attentions=True is not supported for pipeline models at the moment.")
+ output_attentions = False
+ if output_hidden_states:
+ logger.warning_once("output_hidden_states=True is not supported for pipeline models at the moment.")
+ output_hidden_states = False
+
+ transformer_outputs = Qwen3PipelineForwards.qwen3_model_forward(
+ self.model,
+ input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ stage_manager=stage_manager,
+ hidden_states=hidden_states,
+ stage_index=stage_index,
+ shard_config=shard_config,
+ )
+
+ if input_ids is not None:
+ batch_size = input_ids.shape[0]
+ elif inputs_embeds is not None:
+ batch_size = inputs_embeds.shape[0]
+ else:
+ batch_size = hidden_states.shape[0]
+
+ if stage_manager.is_last_stage():
+ hidden_states = transformer_outputs[0]
+ logits = self.score(hidden_states)
+
+ if self.config.pad_token_id is None and batch_size != 1:
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
+
+ if self.config.pad_token_id is None:
+ last_non_pad_token = -1
+ elif input_ids is not None:
+ # To handle both left- and right- padding, we take the rightmost token that is not equal to pad_token_id
+ non_pad_mask = (input_ids != self.config.pad_token_id).to(logits.device, torch.int32)
+ token_indices = torch.arange(input_ids.shape[-1], device=logits.device, dtype=torch.int32)
+ last_non_pad_token = (token_indices * non_pad_mask).argmax(-1)
+ else:
+ last_non_pad_token = -1
+ logger.warning_once(
+ f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+ "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+ )
+
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), last_non_pad_token]
+
+ loss = None
+ if labels is not None:
+ loss = self.loss_function(logits=logits, labels=labels, pooled_logits=pooled_logits, config=self.config)
+
+ if not return_dict:
+ output = (pooled_logits,) + transformer_outputs[1:]
+ return ((loss,) + output) if loss is not None else output
+
+ return SequenceClassifierOutputWithPast(
+ loss=loss,
+ logits=pooled_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ else:
+ hidden_states = transformer_outputs.get("hidden_states")
+ return {"hidden_states": hidden_states}
+
+
+def get_qwen3_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+ def forward(
+ self: Qwen3Attention,
+ hidden_states: torch.Tensor,
+ position_embeddings: Tuple[torch.Tensor, torch.Tensor],
+ attention_mask: Optional[torch.Tensor] = None,
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ if sp_mode is not None:
+ assert sp_mode in ["all_to_all", "split_gather", "ring"], "Invalid sp_mode"
+ assert (sp_size is not None) and (
+ sp_group is not None
+ ), "Must specify sp_size and sp_group for sequence parallel"
+
+ bsz, q_len, _ = hidden_states.size()
+ # sp: modify sp_len when sequence parallel mode is ring
+ if sp_mode in ["split_gather", "ring"]:
+ q_len *= sp_size
+
+ query_states = self.q_proj(hidden_states)
+ key_states = self.k_proj(hidden_states)
+ value_states = self.v_proj(hidden_states)
+        # sp: all-to-all communication when sequence parallelism is enabled
+ if sp_mode == "all_to_all":
+ query_states = all_to_all_comm(query_states, sp_group, fp8_communication=shard_config.fp8_communication)
+ key_states = all_to_all_comm(key_states, sp_group, fp8_communication=shard_config.fp8_communication)
+ value_states = all_to_all_comm(value_states, sp_group, fp8_communication=shard_config.fp8_communication)
+ bsz, q_len, _ = query_states.size()
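+            # after all-to-all, each rank holds the full sequence length but only a shard of the heads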
+
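+        # Qwen3 applies per-head RMSNorm (q_norm/k_norm) to the query and key projections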
+ query_states = self.q_norm(query_states.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2)
+ key_states = self.k_norm(key_states.view(bsz, q_len, -1, self.head_dim)).transpose(1, 2)
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+ kv_seq_len = key_states.shape[-2]
+ if past_key_value is not None:
+ if self.layer_idx is None:
+ raise ValueError(
+ f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
+ "for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
+ "with a layer index."
+ )
+ kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
+ # Because the input can be padded, the absolute sequence length depends on the max position id.
+ cos, sin = position_embeddings
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
+
+ if past_key_value is not None:
+ # Activate slicing cache only if the config has a value `sliding_windows` attribute
+ cache_has_contents = past_key_value.get_seq_length(self.layer_idx) > 0
+ if (
+ getattr(self.config, "sliding_window", None) is not None
+ and kv_seq_len > self.config.sliding_window
+ and cache_has_contents
+ ):
+ slicing_tokens = 1 - self.config.sliding_window
+
+ past_key = past_key_value[self.layer_idx][0]
+ past_value = past_key_value[self.layer_idx][1]
+
+ past_key = past_key[:, :, slicing_tokens:, :].contiguous()
+ past_value = past_value[:, :, slicing_tokens:, :].contiguous()
+
+ if past_key.shape[-2] != self.config.sliding_window - 1:
+ raise ValueError(
+ f"past key must have a shape of (`batch_size, num_heads, self.config.sliding_window-1, head_dim`), got"
+ f" {past_key.shape}"
+ )
+
+ if attention_mask is not None:
+ attention_mask = attention_mask[:, slicing_tokens:]
+ attention_mask = torch.cat([attention_mask, torch.ones_like(attention_mask[:, -1:])], dim=-1)
+
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
+ key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+ # repeat k/v heads if n_kv_heads < n_heads
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
+
+ if shard_config.enable_flash_attention:
+ assert isinstance(attention_mask, dict), "Flash Attention Error: attention_mask should be a dict."
+ attn_output = ColoAttention.attention(
+ query_states,
+ key_states,
+ value_states,
+ dropout_p=0.0 if not self.training else self.attention_dropout,
+ **attention_mask,
+ )
+ else:
+ attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
+
+ if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
+ f" {attn_weights.size()}"
+ )
+
+ if attention_mask is not None:
+ if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
+ raise ValueError(
+ f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
+ )
+ attn_weights = attn_weights + attention_mask
+
+ # upcast attention to fp32
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
+ attn_output = torch.matmul(attn_weights, value_states)
+
+ if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
+ raise ValueError(
+ f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
+ f" {attn_output.size()}"
+ )
+ attn_output = attn_output.transpose(1, 2).contiguous()
+ if sp_mode == "all_to_all":
+ attn_output = attn_output.reshape(bsz, q_len, self.num_heads * self.head_dim)
+ attn_output = all_to_all_comm(
+ attn_output, sp_group, scatter_dim=1, gather_dim=2, fp8_communication=shard_config.fp8_communication
+ )
+ else:
+ attn_output = attn_output.reshape(bsz, q_len, -1)
+
+ attn_output = self.o_proj(attn_output)
+
+ return attn_output, None
+
+ return forward
+
+
+def get_qwen3_model_forward_for_flash_attn(shard_config: ShardConfig, sp_mode=None, sp_size=None, sp_group=None):
+ logger = logging.get_logger(__name__)
+
+ def forward(
+ self,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ cache_position: Optional[torch.LongTensor] = None,
+ return_dict: Optional[bool] = None,
+ force_sp_output_gather: bool = True,
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
+
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # retrieve input_ids and inputs_embeds
+ if input_ids is not None and inputs_embeds is not None:
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
+ elif input_ids is not None:
+ batch_size, seq_length = input_ids.shape
+ elif inputs_embeds is not None:
+ batch_size, seq_length, _ = inputs_embeds.shape
+ else:
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
+
+ if inputs_embeds is None:
+ inputs_embeds = self.embed_tokens(input_ids)
+
+ seq_length_with_past = seq_length
+ past_key_values_length = 0
+
+ if past_key_values is not None:
+ past_key_values_length = past_key_values[0][0].shape[2]
+ seq_length_with_past = seq_length_with_past + past_key_values_length
+
+ if position_ids is None:
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
+ position_ids = torch.arange(
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
+ )
+ position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
+ else:
+ position_ids = position_ids.view(-1, seq_length).long()
+
+ # embed positions
+ hidden_states = inputs_embeds
+
+ if shard_config.enable_flash_attention:
+ # in this case, attention_mask is a dict rather than a tensor
+ mask_shape = (batch_size, 1, seq_length, seq_length_with_past)
+ attention_mask = ColoAttention.prepare_attn_kwargs(
+ mask_shape,
+ hidden_states.dtype,
+ hidden_states.device,
+ q_padding_mask=attention_mask,
+ is_causal=True,
+ )
+ else:
+ attention_mask = _prepare_4d_causal_attention_mask(
+ attention_mask,
+ (batch_size, seq_length),
+ inputs_embeds,
+ past_key_values_length,
+ sliding_window=self.config.sliding_window,
+ )
+
+ if (self.gradient_checkpointing or sp_mode in ["ring", "all_to_all"]) and self.training:
+ if use_cache:
+ logger.warning_once(
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+ )
+ use_cache = False
+
+ # decoder layers
+ all_hidden_states = () if output_hidden_states else None
+ all_self_attns = () if output_attentions else None
+ next_decoder_cache = None
+ position_embeddings = self.rotary_emb(hidden_states, position_ids)
+
+ if sp_mode in ["ring", "split_gather"]:
+ hidden_states = split_forward_gather_backward(
+ hidden_states, 1, sp_group, fp8_communication=shard_config.fp8_communication
+ )
+ elif sp_mode == "all_to_all":
+ hidden_states = split_forward_gather_backward(
+ hidden_states, 1, sp_group, 1 / sp_size, fp8_communication=shard_config.fp8_communication
+ )
+
+ for decoder_layer in self.layers:
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ if self.gradient_checkpointing and self.training:
+ layer_outputs = self._gradient_checkpointing_func(
+ decoder_layer.__call__,
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+ else:
+ layer_outputs = decoder_layer(
+ hidden_states,
+ attention_mask,
+ position_ids,
+ past_key_values,
+ output_attentions,
+ use_cache,
+ cache_position,
+ position_embeddings,
+ )
+
+ hidden_states = layer_outputs[0]
+
+ if output_attentions:
+ all_self_attns += (layer_outputs[1],)
+
+ hidden_states = self.norm(hidden_states)
+
+ if shard_config.enable_sequence_parallelism:
+ if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
+ hidden_states = gather_sp_output(hidden_states, shard_config)
+
+ # add hidden states from the last decoder layer
+ if output_hidden_states:
+ all_hidden_states += (hidden_states,)
+
+ next_cache = next_decoder_cache if use_cache else None
+
+ if not return_dict:
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
+ return BaseModelOutputWithPast(
+ last_hidden_state=hidden_states,
+ past_key_values=next_cache,
+ hidden_states=all_hidden_states,
+ attentions=all_self_attns,
+ )
+
+ return forward
+
+
+def get_lm_forward_with_dist_cross_entropy(shard_config: ShardConfig):
+ def forward(
+ self: Qwen3ForCausalLM,
+ input_ids: torch.LongTensor = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ position_ids: Optional[torch.LongTensor] = None,
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
+ inputs_embeds: Optional[torch.FloatTensor] = None,
+ labels: Optional[torch.LongTensor] = None,
+ use_cache: Optional[bool] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ return_dict: Optional[bool] = None,
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
+ r"""
+ Args:
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+ Returns:
+
+ Example:
+
+ ```python
+        >>> from transformers import AutoTokenizer, Qwen3ForCausalLM
+
+        >>> model = Qwen3ForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
+        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
+
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
+
+ >>> # Generate
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
+ ```"""
+
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+ output_hidden_states = (
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+ )
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
+ outputs = self.model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ past_key_values=past_key_values,
+ inputs_embeds=inputs_embeds,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ force_sp_output_gather=False,
+ )
+
+ hidden_states = outputs[0]
+ logits = self.lm_head(hidden_states)
+ logits = logits.float()
+ loss = None
+ if labels is not None:
+ loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
+
+ if not return_dict:
+ output = (logits,) + outputs[1:]
+ return (loss,) + output if loss is not None else output
+
+ return CausalLMOutputWithPast(
+ loss=loss,
+ logits=logits,
+ past_key_values=outputs.past_key_values,
+ hidden_states=outputs.hidden_states,
+ attentions=outputs.attentions,
+ )
+
+ return forward
diff --git a/colossalai/shardformer/policies/auto_policy.py b/colossalai/shardformer/policies/auto_policy.py
index a69053b2ff56..3d61af1e0aea 100644
--- a/colossalai/shardformer/policies/auto_policy.py
+++ b/colossalai/shardformer/policies/auto_policy.py
@@ -220,6 +220,16 @@ class PolicyLocation:
"transformers.models.qwen2.modeling_qwen2.Qwen2ForSequenceClassification": PolicyLocation(
file_name="qwen2", class_name="Qwen2ForSequenceClassificationPolicy"
),
+ # Qwen3
+ "transformers.models.qwen3.modeling_qwen3.Qwen3Model": PolicyLocation(
+ file_name="qwen3", class_name="Qwen3ModelPolicy"
+ ),
+ "transformers.models.qwen3.modeling_qwen3.Qwen3ForCausalLM": PolicyLocation(
+ file_name="qwen3", class_name="Qwen3ForCausalLMPolicy"
+ ),
+ "transformers.models.qwen3.modeling_qwen3.Qwen3ForSequenceClassification": PolicyLocation(
+ file_name="qwen3", class_name="Qwen3ForSequenceClassificationPolicy"
+ ),
# command
"transformers.models.cohere.modeling_cohere.CohereModel": PolicyLocation(
file_name="command", class_name="CommandModelPolicy"
diff --git a/colossalai/shardformer/policies/qwen3.py b/colossalai/shardformer/policies/qwen3.py
new file mode 100644
index 000000000000..e9cc9543278a
--- /dev/null
+++ b/colossalai/shardformer/policies/qwen3.py
@@ -0,0 +1,541 @@
+# Modified from Qwen2 policy
+from functools import partial
+from typing import Callable, Dict, List, Union
+
+import torch.nn as nn
+from torch import Tensor
+from torch.nn import Module
+from transformers.models.qwen3.modeling_qwen3 import (
+ Qwen3Attention,
+ Qwen3DecoderLayer,
+ Qwen3ForCausalLM,
+ Qwen3ForSequenceClassification,
+ Qwen3Model,
+)
+
+from colossalai.shardformer.layer import (
+ FusedRMSNorm,
+ Linear1D_Col,
+ Linear1D_Row,
+ LinearWithGradAccum,
+ PaddingEmbedding,
+ RMSNorm,
+ VocabParallelEmbedding1D,
+)
+
+from ..modeling.qwen3 import (
+ Qwen3PipelineForwards,
+ get_lm_forward_with_dist_cross_entropy,
+ get_qwen3_flash_attention_forward,
+ get_qwen3_model_forward_for_flash_attn,
+)
+from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
+
+__all__ = ["Qwen3Policy", "Qwen3ForCausalLMPolicy", "Qwen3ForSequenceClassificationPolicy"]
+
+
+class Qwen3Policy(Policy):
+ def __init__(self) -> None:
+ super().__init__()
+ import transformers
+ from packaging.version import Version
+
+ assert Version(transformers.__version__) >= Version(
+ "4.51.0"
+ ), "The Qwen3 model should run on a transformers version of 4.51.0 or higher."
+
+ def config_sanity_check(self):
+ pass
+
+ def preprocess(self):
+ self.tie_weight = self.tie_weight_check()
+ self.origin_attn_implement = self.model.config._attn_implementation
+ return self.model
+
+ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+
+ policy = {}
+
+ embedding_cls = None
+ if self.shard_config.enable_tensor_parallelism:
+ embedding_cls = VocabParallelEmbedding1D
+ else:
+ if self.tie_weight:
+ embedding_cls = PaddingEmbedding
+ norm_cls = FusedRMSNorm if self.shard_config.enable_fused_normalization else RMSNorm
+
+ sp_mode = self.shard_config.sequence_parallelism_mode or None
+ sp_size = self.shard_config.sequence_parallel_size or None
+ sp_group = self.shard_config.sequence_parallel_process_group or None
+ sp_partial_derived = sp_mode in ["split_gather", "ring"]
+ if sp_mode == "all_to_all":
+ decoder_attribute_replacement = {
+ "num_heads": self.model.config.num_attention_heads // sp_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["num_key_value_heads"] = self.model.config.num_key_value_heads // sp_size
+
+ policy[Qwen3Attention] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ )
+
+ use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
+ if self.shard_config.enable_tensor_parallelism:
+ assert (
+ self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of attention heads must be divisible by tensor parallel size."
+ if hasattr(self.model.config, "num_key_value_heads"):
+ assert (
+ self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+ ), f"The number of key_value heads must be divisible by tensor parallel size."
+ decoder_attribute_replacement = {
+ "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+ "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+ }
+ if getattr(self.model.config, "num_key_value_heads", False):
+ decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
+ self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+ )
+
+ policy[Qwen3DecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=Linear1D_Col,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=Linear1D_Col,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=Linear1D_Col,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=Linear1D_Row,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=Linear1D_Col,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=Linear1D_Col,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=Linear1D_Row,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ ],
+ )
+ elif use_zbv:
+ policy[Qwen3DecoderLayer] = ModulePolicyDescription(
+ attribute_replacement=decoder_attribute_replacement,
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="self_attn.q_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.k_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.v_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="self_attn.o_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.gate_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.up_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ SubModuleReplacementDescription(
+ suffix="mlp.down_proj",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ seq_parallel_mode=sp_mode,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ ),
+ ],
+ )
+
+ if embedding_cls is not None:
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="embed_tokens",
+ target_module=embedding_cls,
+ kwargs=(
+ {
+ "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+ "fp8_communication": self.shard_config.fp8_communication,
+ }
+ if self.shard_config.enable_tensor_parallelism
+ else {"make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by}
+ ),
+ ),
+ policy=policy,
+ target_key=Qwen3Model,
+ )
+
+ # optimization configuration
+ self.append_or_create_submodule_replacement(
+ description=[
+ SubModuleReplacementDescription(
+ suffix="input_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
+ ),
+ SubModuleReplacementDescription(
+ suffix="post_attention_layernorm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
+ ),
+ ],
+ policy=policy,
+ target_key=Qwen3DecoderLayer,
+ )
+
+ self.append_or_create_submodule_replacement(
+ description=SubModuleReplacementDescription(
+ suffix="norm",
+ target_module=norm_cls,
+ kwargs={"sp_partial_derived": sp_partial_derived},
+ ),
+ policy=policy,
+ target_key=Qwen3Model,
+ )
+
+ if self.shard_config.enable_flash_attention or self.shard_config.enable_sequence_parallelism:
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_qwen3_flash_attention_forward(self.shard_config, sp_mode, sp_size, sp_group),
+ },
+ policy=policy,
+ target_key=Qwen3Attention,
+ )
+ if self.pipeline_stage_manager is None:
+ # replace qwen3 model forward method
+ self.append_or_create_method_replacement(
+ description={
+ "forward": get_qwen3_model_forward_for_flash_attn(
+ self.shard_config, sp_mode, sp_size, sp_group
+ ),
+ },
+ policy=policy,
+ target_key=Qwen3Model,
+ )
+
+ return policy
+
+ def postprocess(self):
+ return self.model
+
+ def set_pipeline_forward(self, model_cls: nn.Module, new_forward: Callable, policy: Dict) -> None:
+ """If under pipeline parallel setting, replacing the original forward method of huggingface
+ to customized forward method, and add this changing to policy."""
+ if self.pipeline_stage_manager is None:
+ return
+
+ stage_manager = self.pipeline_stage_manager
+ if self.model.__class__.__name__ == "Qwen3Model":
+ module = self.model
+ else:
+ module = self.model.model
+
+ if stage_manager.is_interleave:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_manager.stage_indices = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {
+ "forward": partial(new_forward, stage_manager=stage_manager, shard_config=self.shard_config)
+ }
+
+ else:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_index = stage_manager.get_stage_index(layers_per_stage)
+ method_replacement = {
+ "forward": partial(
+ new_forward, stage_manager=stage_manager, stage_index=stage_index, shard_config=self.shard_config
+ )
+ }
+
+        self.append_or_create_method_replacement(description=method_replacement, policy=policy, target_key=model_cls)
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ assert self.pipeline_stage_manager is not None
+
+ if self.model.__class__.__name__ == "Qwen3Model":
+ module = self.model
+ else:
+ module = self.model.model
+
+ stage_manager = self.pipeline_stage_manager
+
+ held_layers = []
+ held_layers.append(module.rotary_emb)
+ if stage_manager.is_interleave:
+ assert stage_manager.num_model_chunks is not None
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ stage_indices = stage_manager.get_stage_index(layers_per_stage)
+ if stage_manager.is_first_stage(ignore_chunk=True):
+ held_layers.append(module.embed_tokens)
+ for start_idx, end_idx in stage_indices:
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+ not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+ ):
+ held_layers.append(module.norm)
+
+ else:
+ layers_per_stage = stage_manager.distribute_layers(len(module.layers))
+ if stage_manager.is_first_stage():
+ held_layers.append(module.embed_tokens)
+ start_idx, end_idx = stage_manager.get_stage_index(layers_per_stage)
+ held_layers.extend(module.layers[start_idx:end_idx])
+ if stage_manager.is_last_stage():
+ held_layers.append(module.norm)
+
+ return held_layers
+
+
+class Qwen3ModelPolicy(Qwen3Policy):
+ def module_policy(self):
+ policy = super().module_policy()
+
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(
+ model_cls=Qwen3Model, new_forward=Qwen3PipelineForwards.qwen3_model_forward, policy=policy
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ held_layers = super().get_held_layers()
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in Qwen3 model"""
+ return []
+
+
+class Qwen3ForCausalLMPolicy(Qwen3Policy):
+ def module_policy(self):
+ policy = super().module_policy()
+ setattr(self.shard_config, "causal_lm", True)
+ use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+
+ if self.shard_config.enable_tensor_parallelism:
+            # add a new item for causal lm
+ new_item = {
+ Qwen3ForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=Linear1D_Col,
+ kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+ )
+ ],
+ method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
+ )
+ }
+ policy.update(new_item)
+ elif use_zbv:
+            # add a new item for causal lm
+ new_item = {
+ Qwen3ForCausalLM: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="lm_head",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+ )
+ ],
+ method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
+ )
+ }
+ policy.update(new_item)
+
+ if self.pipeline_stage_manager:
+ # set None as default
+ self.set_pipeline_forward(
+ model_cls=Qwen3ForCausalLM, new_forward=Qwen3PipelineForwards.qwen3_for_causal_lm_forward, policy=policy
+ )
+
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_interleave:
+ if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+ not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+ ):
+ held_layers.append(self.model.lm_head)
+ else:
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.lm_head)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ qwen3_model = self.model.model
+ if self.pipeline_stage_manager and self.pipeline_stage_manager.num_stages > 1:
+            if id(qwen3_model.embed_tokens.weight) == id(self.model.lm_head.weight):
+ # tie weights
+ return [
+ {
+ 0: qwen3_model.embed_tokens.weight,
+ self.pipeline_stage_manager.num_stages - 1: self.model.lm_head.weight,
+ }
+ ]
+ return []
+
+
+class Qwen3ForSequenceClassificationPolicy(Qwen3Policy):
+ def module_policy(self):
+ policy = super().module_policy()
+ use_zbv = self.pipeline_stage_manager is not None and self.pipeline_stage_manager.use_zbv
+ if self.shard_config.enable_tensor_parallelism:
+ # add a new item for sequence classification
+ new_item = {
+ Qwen3ForSequenceClassification: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score",
+ target_module=Linear1D_Col,
+ kwargs=dict(
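+ # the classification head is only num_labels wide, so gather its full
+ # output on every rank instead of keeping it column-sharded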
+ gather_output=True,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+ elif use_zbv:
+ new_item = {
+ Qwen3ForSequenceClassification: ModulePolicyDescription(
+ sub_module_replacement=[
+ SubModuleReplacementDescription(
+ suffix="score",
+ target_module=LinearWithGradAccum,
+ kwargs=dict(
+ gather_output=True,
+ fp8_communication=self.shard_config.fp8_communication,
+ use_zbv=use_zbv,
+ ),
+ )
+ ]
+ )
+ }
+ policy.update(new_item)
+ # NOTE: the pipeline forward for sequence classification is still to be confirmed
+ if self.pipeline_stage_manager:
+ self.set_pipeline_forward(
+ model_cls=Qwen3ForSequenceClassification,
+ new_forward=Qwen3PipelineForwards.qwen3_for_sequence_classification_forward,
+ policy=policy,
+ )
+ return policy
+
+ def get_held_layers(self) -> List[Module]:
+ """Get pipeline layers for current stage."""
+ stage_manager = self.pipeline_stage_manager
+ held_layers = super().get_held_layers()
+ if stage_manager.is_interleave:
+ if (stage_manager.use_zbv and stage_manager.is_first_stage(ignore_chunk=True)) or (
+ not stage_manager.use_zbv and stage_manager.is_last_stage(ignore_chunk=True)
+ ):
+ held_layers.append(self.model.score)
+ else:
+ if stage_manager.is_last_stage(ignore_chunk=True):
+ held_layers.append(self.model.score)
+ return held_layers
+
+ def get_shared_params(self) -> List[Dict[int, Tensor]]:
+ """No shared params in Qwen3 for sequence classification model"""
+ return []
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index 4adc386192d3..3127aea1b283 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -13,6 +13,7 @@
from .mixtral import *
from .opt import *
from .qwen2 import *
+from .qwen3 import *
from .sam import *
from .t5 import *
from .vit import *
diff --git a/tests/kit/model_zoo/transformers/qwen3.py b/tests/kit/model_zoo/transformers/qwen3.py
new file mode 100644
index 000000000000..97d4bd79cdf4
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/qwen3.py
@@ -0,0 +1,121 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+try:
+ from transformers import Qwen3Config
+
+ HAS_QWEN3 = True
+except ImportError:
+ HAS_QWEN3 = False
+
+if HAS_QWEN3:
+ # ===============================
+ # Register Qwen3
+ # ===============================
+
+ def data_gen():
+ # the input ids correspond to the sentence
+ # 'This is a test sentence. This is a test sentence. This is a test sentence. This is a test sentence.'
+ #
+ # the code to reproduce them is given below:
+ # -----------------------------------
+ # from transformers import AutoTokenizer
+ # tokenizer = AutoTokenizer.from_pretrained('Qwen/Qwen3-4B')
+ # input = "This is a test sentence. This is a test sentence. This is a test sentence. This is a test sentence."
+ # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+ # -----------------------------------
+
+ # NOTE: by the sequence-parallel convention, the sequence length needs to be a multiple of 4
+ input_ids = torch.tensor(
+ [[1986, 374, 264, 1273, 11652, 13, 1096, 374, 264, 1273, 11652, 13,
+ 1096, 374, 264, 1273, 11652, 13, 1096, 374, 264, 1273, 11652, 13]],
+ dtype=torch.long,
+ )
+ attention_mask = torch.ones(input_ids.shape, dtype=torch.long)
+ return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+ # label is needed for causal lm
+ def data_gen_for_causal_lm():
+ data = data_gen()
+ labels = data["input_ids"].clone()
+ data["labels"] = labels
+ return data
+
+ # the model output is already dict-like, so no transform is needed
+ output_transform_fn = lambda x: x
+
+ # function to get the loss
+ loss_fn = lambda output: output["last_hidden_state"].mean()
+ loss_fn_for_causal_lm = lambda output: output["loss"]
+ loss_fn_for_seq_classification = lambda output: output["logits"].mean()
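+ # mean() merely provides a scalar for backward() when the model does not return a loss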
+
+ config = Qwen3Config(
+ hidden_size=128,
+ intermediate_size=256,
+ max_window_layers=4,
+ num_attention_heads=16,
+ num_hidden_layers=4,
+ num_key_value_heads=16,
+ attn_implementation="sdpa", # for tests on fp32
+ sliding_window=None, # not supported by sdpa
+ use_cache=False,
+ )
+
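+ # Qwen3ForSequenceClassification pools the last non-padding token, so a pad token id is required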
+ config.pad_token_id = 0
+
+ # register the following models
+ # transformers.Qwen3Model,
+ # transformers.Qwen3ForCausalLM,
+ # transformers.Qwen3ForSequenceClassification,
+ model_zoo.register(
+ name="transformers_qwen3",
+ model_fn=lambda: transformers.Qwen3Model(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn,
+ model_attribute=ModelAttribute(has_control_flow=True),
+ )
+ model_zoo.register(
+ name="transformers_qwen3_for_causal_lm",
+ model_fn=lambda: transformers.Qwen3ForCausalLM(config),
+ data_gen_fn=data_gen_for_causal_lm,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_causal_lm,
+ model_attribute=ModelAttribute(has_control_flow=True),
+ )
+ model_zoo.register(
+ name="transformers_qwen3_for_sequence_classification",
+ model_fn=lambda: transformers.Qwen3ForSequenceClassification(config),
+ data_gen_fn=data_gen,
+ output_transform_fn=output_transform_fn,
+ loss_fn=loss_fn_for_seq_classification,
+ model_attribute=ModelAttribute(has_control_flow=True),
+ )
diff --git a/tests/test_shardformer/test_model/test_shard_qwen3.py b/tests/test_shardformer/test_model/test_shard_qwen3.py
new file mode 100644
index 000000000000..9670a5999e8c
--- /dev/null
+++ b/tests/test_shardformer/test_model/test_shard_qwen3.py
@@ -0,0 +1,302 @@
+import pytest
+import torch
+import transformers
+from packaging import version
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.utils import Randomizer
+from colossalai.tensor.d_tensor.api import clear_layout_converter
+from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+from tests.test_shardformer.test_model._utils import (
+ build_model_from_hybrid_plugin,
+ check_all_grad_tensors,
+ check_loss,
+ check_output_hidden_state,
+ check_weight,
+ get_grad_tensors_for_check,
+ run_forward_backward_with_hybrid_plugin,
+ unwrap_model,
+)
+
+
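+# To run this file locally (a sketch, assuming a single node: test_qwen3
+# spawns 4 processes and test_qwen3_3d spawns 8):
+#   pytest tests/test_shardformer/test_model/test_shard_qwen3.py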
+def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config):
+ org_model, org_optimizer, sharded_model, sharded_optimizer, criterion, booster = build_model_from_hybrid_plugin(
+ model_fn, loss_fn, test_config
+ )
+
+ org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
+ org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
+ )
+
+ stage_manager = booster.plugin.stage_manager
+ tp_group = booster.plugin.tp_group
+
+ # unwrap model
+ qwen3_model = unwrap_model(org_model, "Qwen3Model", "model")
+ shard_qwen3_model = unwrap_model(sharded_model, "Qwen3Model", "model")
+
+ row_layer_for_check = ["layers[0].self_attn.q_proj", "embed_tokens"]
+ col_layer_for_check = ["layers[0].self_attn.o_proj"]
+
+ # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
+ grads_to_check = {}
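+ # gradients are only directly comparable on the first pipeline stage and
+ # with ZeRO disabled, since ZeRO partitions gradients across ranks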
+ if (stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True)) and booster.plugin.zero_stage == 0:
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-6, 1e-4
+ else:
+ atol, rtol = 5e-3, 5e-3
+ row_layer_grads = get_grad_tensors_for_check(
+ qwen3_model, shard_qwen3_model, row_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=0, verbose=False
+ )
+ col_layer_grads = get_grad_tensors_for_check(
+ qwen3_model, shard_qwen3_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+ )
+ grads_to_check.update(col_layer_grads)
+ grads_to_check.update(row_layer_grads)
+
+ # optimizer executes step
+ org_optimizer.step()
+ sharded_optimizer.step()
+
+ # check last hidden state & loss
+ if stage_manager is None or stage_manager.is_last_stage(ignore_chunk=True):
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-5, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+
+ if org_model.__class__.__name__ == "Qwen3Model":
+ check_output_hidden_state(org_output, sharded_output, stage_manager, atol=atol, rtol=rtol)
+
+ check_loss(org_loss, sharded_loss, atol=atol, rtol=rtol)
+
+ # check weights
+ if stage_manager is None or stage_manager.is_first_stage(ignore_chunk=True):
+ if test_config["precision"] == "fp32":
+ atol, rtol = 1e-3, 1e-3
+ else:
+ atol, rtol = 5e-3, 5e-3
+ check_weight(
+ qwen3_model, shard_qwen3_model, col_layer_for_check, tp_group, atol=atol, rtol=rtol, dim=1, verbose=False
+ )
+
+ # check grads
+ check_all_grad_tensors(grads_to_check)
+
+ torch.cuda.empty_cache()
+
+
+@parameterize(
+ "test_config",
+ [
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "sp_size": 2,
+ "num_microbatches": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "split_gather",
+ "enable_flash_attention": True,
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ { # Ulysses + Flash attention
+ "tp_size": 1,
+ "pp_size": 2,
+ "sp_size": 2,
+ "num_microbatches": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "enable_flash_attention": True,
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 4,
+ "use_lazy_init": False,
+ "precision": "fp32",
+ },
+ {
+ "tp_size": 4,
+ "pp_size": 1,
+ "enable_all_optimization": True,
+ "use_lazy_init": False,
+ "precision": "fp32",
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 4,
+ "num_microbatches": 4,
+ "enable_all_optimization": False,
+ "use_lazy_init": False,
+ "precision": "fp32",
+ },
+ {"tp_size": 2, "pp_size": 1, "enable_all_optimization": True, "use_lazy_init": False, "precision": "fp32"},
+ {
+ "tp_size": 2,
+ "pp_size": 1,
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "zero_stage": 2,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "sp_size": 2,
+ "num_microbatches": 2,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "ring",
+ "enable_flash_attention": True,
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 1,
+ "sp_size": 2,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "all_to_all",
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 4,
+ "pp_size": 1,
+ "num_microbatches": 1,
+ "enable_sequence_parallelism": True,
+ "sequence_parallelism_mode": "split_gather",
+ "enable_flash_attention": False,
+ "use_lazy_init": True,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 1,
+ "pp_size": 2,
+ "num_microbatches": 2,
+ "enable_all_optimization": True,
+ "use_lazy_init": True,
+ "zero_stage": 1,
+ "precision": "fp16",
+ "initial_scale": 1,
+ },
+ ],
+)
+def run_qwen3_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen3")
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ try:
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+ except Exception:
+ print(f"Failed config: {test_config}")
+ raise
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+@parameterize(
+ "test_config",
+ [
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "num_microbatches": 4,
+ "enable_all_optimization": False,
+ "use_lazy_init": False,
+ "precision": "fp32",
+ "initial_scale": 1,
+ },
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "num_microbatches": 4,
+ "enable_all_optimization": False,
+ "use_lazy_init": False,
+ "precision": "fp16",
+ "zero_stage": 1,
+ "initial_scale": 1,
+ },
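+ # interleaved pipeline schedule: each stage holds two model chunks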
+ {
+ "tp_size": 2,
+ "pp_size": 2,
+ "pp_style": "interleaved",
+ "num_model_chunks": 2,
+ "num_microbatches": 4,
+ "enable_all_optimization": False,
+ "precision": "fp16",
+ "zero_stage": 1,
+ "initial_scale": 1,
+ },
+ ],
+)
+def run_qwen3_3d_test(test_config):
+ sub_model_zoo = model_zoo.get_sub_registry("transformers_qwen3")
+
+ for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+ try:
+ check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+ except Exception:
+ print(f"Failed config: {test_config}")
+ raise
+
+ clear_layout_converter()
+ Randomizer.reset_index()
+ torch.cuda.empty_cache()
+
+
+def check_qwen3(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_qwen3_test()
+
+
+def check_qwen3_3d(rank, world_size, port):
+ disable_existing_loggers()
+ colossalai.launch(rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+ run_qwen3_3d_test()
+
+
+@pytest.mark.skipif(version.parse(transformers.__version__) < version.parse("4.51.0"), reason="Requires transformers version 4.51.0 or later")
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_qwen3():
+ spawn(check_qwen3, 4)
+
+
+@pytest.mark.skipif(version.parse(transformers.__version__) < version.parse("4.51.0"), reason="Requires transformers version 4.51.0 or later")
+@pytest.mark.largedist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_qwen3_3d():
+ spawn(check_qwen3_3d, 8)
+
+
+if __name__ == "__main__":
+ test_qwen3()
+ test_qwen3_3d()