
Commit d097224

[feat] support qwen3 in shardformer
Parent: 97f4bee

File tree

9 files changed: +1829, -42 lines


.github/workflows/build_on_pr.yml

Lines changed: 4 additions & 0 deletions
@@ -138,6 +138,10 @@ jobs:
           cp -p -r /github/home/cuda_ext_cache/* /__w/ColossalAI/ColossalAI/
         fi

+      - name: Install flash-attention
+        run: |
+          pip install flash-attn==2.7.4.post1 --no-build-isolation
+
      - name: Install Colossal-AI
        run: |
          BUILD_EXT=1 pip install -v -e .
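The new step pins flash-attn 2.7.4.post1 and installs it with `--no-build-isolation` so the prebuilt wheel is used against the CI's existing torch/CUDA toolchain rather than compiling from source. A minimal sketch of how code in the repository could gate optional flash-attention paths on that install (the flag name `HAS_FLASH_ATTN` is a hypothetical example, not a ColossalAI API):

```python
# Hypothetical guard, not part of this commit: detect whether the
# flash-attn wheel installed in CI is actually importable before
# enabling flash-attention code paths.
try:
    import flash_attn  # provided by `pip install flash-attn==2.7.4.post1`

    HAS_FLASH_ATTN = True
except ImportError:
    HAS_FLASH_ATTN = False
```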

colossalai/shardformer/modeling/qwen2.py

Lines changed: 12 additions & 21 deletions
@@ -4,31 +4,23 @@
 import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
+from transformers.modeling_attn_mask_utils import (
+    _prepare_4d_causal_attention_mask,
+    _prepare_4d_causal_attention_mask_for_sdpa,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
     SequenceClassifierOutputWithPast,
 )
-
-try:
-    from transformers.modeling_attn_mask_utils import (
-        _prepare_4d_causal_attention_mask,
-        _prepare_4d_causal_attention_mask_for_sdpa,
-    )
-    from transformers.models.qwen2.modeling_qwen2 import (
-        Qwen2Attention,
-        Qwen2ForCausalLM,
-        Qwen2ForSequenceClassification,
-        Qwen2Model,
-        apply_rotary_pos_emb,
-        repeat_kv,
-    )
-except ImportError:
-    Qwen2Model = "Qwen2Model"
-    Qwen2ForCausalLM = "Qwen2ForCausalLM"
-    Qwen2Attention = "Qwen2Attention"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
-
+from transformers.models.qwen2.modeling_qwen2 import (
+    Qwen2Attention,
+    Qwen2ForCausalLM,
+    Qwen2ForSequenceClassification,
+    Qwen2Model,
+    apply_rotary_pos_emb,
+    repeat_kv,
+)
 from transformers.utils import logging

 from colossalai.pipeline.stage_manager import PipelineStageManager
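This hunk drops the try/except fallback that bound the Qwen2 class names to placeholder strings on older transformers releases; the imports of `Qwen2Attention`, `apply_rotary_pos_emb`, and `repeat_kv` are now unconditional, so a transformers version that ships Qwen2 is assumed. For reference, `repeat_kv` in upstream transformers expands grouped-query key/value heads roughly as follows (paraphrased from the library; exact details are version-dependent):

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head n_rep times for grouped-query attention,
    going from (batch, num_kv_heads, seq_len, head_dim) to
    (batch, num_kv_heads * n_rep, seq_len, head_dim)."""
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
```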
@@ -434,7 +426,6 @@ def qwen2_for_sequence_classification_forward(
         logits = self.score(hidden_states)

         if self.config.pad_token_id is None and batch_size != 1:
-            print(self.config.pad_token_id)
             raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
         if self.config.pad_token_id is None:
             sequence_lengths = -1
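This hunk removes a leftover debug print before the ValueError. For context on the branch it sits in: when `pad_token_id` is set, the upstream Qwen2 sequence-classification head pools the logit at each sequence's last non-padding position. A sketch of that selection, paraphrased from the upstream pattern (the helper name is hypothetical and exact handling is version-dependent):

```python
import torch


def last_token_indices(input_ids: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Index of the first pad token minus one; sequences with no padding
    # yield argmax 0, and the modulo wraps -1 to the final position.
    lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    return lengths % input_ids.shape[-1]


# usage sketch:
# pooled = logits[torch.arange(logits.shape[0]), last_token_indices(input_ids, pad_id)]
```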
