
Commit bbb2c21

[shardformer] fix chatglm implementation (#5644)
* [shardformer] fix chatglm policy
* [shardformer] fix chatglm flash attn
* [shardformer] update readme
* [shardformer] fix chatglm init
* [shardformer] fix chatglm test
* [pipeline] fix chatglm merge batch
1 parent 5d88ef1 commit bbb2c21

File tree: 11 files changed (+193 -117 lines changed)


colossalai/pipeline/schedule/one_f_one_b.py

Lines changed: 9 additions & 3 deletions
@@ -7,7 +7,7 @@
 from torch.utils._pytree import tree_map

 from colossalai.accelerator import get_accelerator
-from colossalai.interface import OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper
 from colossalai.pipeline.p2p import PipelineP2PCommunication, create_send_metadata
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.utils import get_current_device
@@ -327,7 +327,10 @@ def run_forward_only(
             self.send_forward(output_obj)

         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}

     def run_forward_backward(
@@ -410,7 +413,10 @@ def run_forward_backward(
         assert all(len(v) == 0 for v in input_objs) and all(len(v) == 0 for v in output_objs)

         if outputs is not None:
-            outputs = merge_batch(outputs)
+            if isinstance(model, ModelWrapper):
+                model = model.unwrap()
+            batch_size_dim = getattr(model, "batch_size_dim", 0)
+            outputs = merge_batch(outputs, batch_size_dim)
         return {"loss": accum_loss, "outputs": outputs}

     def forward_backward_step(
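
A note on the `batch_size_dim` change above: ChatGLM2 keeps its activations in `[seq_len, batch, hidden]` layout, so per-micro-batch outputs have to be concatenated along dim 1 rather than the default dim 0, which is presumably why the schedule now reads `batch_size_dim` from the unwrapped model. Below is a minimal sketch of the merge idea in plain PyTorch; `merge_micro_batch_outputs` is an illustrative stand-in, not the library's `merge_batch`.

```python
from typing import Dict, List

import torch


def merge_micro_batch_outputs(
    outputs: List[Dict[str, torch.Tensor]], batch_size_dim: int = 0
) -> Dict[str, torch.Tensor]:
    # Concatenate each field of the per-micro-batch outputs along the batch dimension.
    return {key: torch.cat([out[key] for out in outputs], dim=batch_size_dim) for key in outputs[0]}


# Two micro-batches of logits shaped [seq_len=4, micro_batch=2, hidden=8]:
micro_outputs = [{"logits": torch.randn(4, 2, 8)} for _ in range(2)]
print(merge_micro_batch_outputs(micro_outputs, batch_size_dim=1)["logits"].shape)  # torch.Size([4, 4, 8])
```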

colossalai/shardformer/README.md

Lines changed: 79 additions & 42 deletions
@@ -114,30 +114,30 @@ We will follow this roadmap to develop Shardformer:
 - [x] Unit Testing
 - [ ] Policy Implementation

-| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
-| :------: | :-----: | :-----: | :--------: | :---------: | :------: | :-----: | :-----: | :--------: | :---------: |
-| bert | [] | [] | [] | [] | [] | [] | [] | [] | [] |
-| t5 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
-| llama V1/V2 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
-| gpt2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
-| opt | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
-| bloom | [] | [] | [] | [] | [] | [] | [] | [] | [] |
-| chatglm2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
-| vit | [] | [] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
-| whisper | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
-| sam | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
-| blip2 | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
-| falcon | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
-| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
-| mistral | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
+| model | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
+|:-----------:|:---------------:|:-----------------:|:-------------------:|:-------:|:-----------:|:------------------:|:---------------:|:-----------------:|:-------:|
+| bert | [] | [] | [] | [] | [] | [] | [] | [] | [] |
+| t5 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
+| llama V1/V2 | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
+| gpt2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
+| opt | [] | [] | [] | [] | [] | [] | [] | [ ] | [ ] |
+| bloom | [] | [] | [] | [] | [] | [] | [] | [] | [] |
+| chatglm2 | [] | [] | [] | [] | [] | [] | [] | [] | [] |
+| vit | [] | [] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
+| whisper | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
+| sam | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
+| blip2 | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |
+| falcon | [] | [] | [] | [] | [] | [ ] | [] | [ ] | [ ] |
+| roberta | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| albert | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| ernie | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-neo | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| gpt-j | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| beit | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| swin V2 | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| qwen | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
+| mistral | [] | [ ] | [ ] | [] | [] | [] | [] | [ ] | [ ] |


 ## 💡 API Design
@@ -391,6 +391,43 @@ _POLICY_LIST = {
 }
 ```

+#### How to support models that are in the huggingface model hub but not in the transformers library
+
+There are two cases:
+
+1. The modeling file is in the `transformers` library but the model weights are not. E.g. the model structure of "01-ai/Yi-34B" is the same as LLaMA; only the weights live outside the `transformers` library. In this case, we support llama as usual and Yi-34B is covered by the llama policy, so no new policy is needed for Yi-34B.
+2. The modeling file is not in the `transformers` library, e.g. "THUDM/chatglm2-6b".
+
+Taking "THUDM/chatglm2-6b" as an example, here is how to support such a model in `shardformer`.
+
+Unlike llama, which lives in the `transformers` library, the chatglm2 model cannot be imported directly. Thus, the policy key should be the class name as a string rather than the class itself.
+
+E.g. for llama:
+```python
+policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
+```
+
+and for chatglm2:
+```python
+policy["GLMBlock"] = ModulePolicyDescription(...)
+```
+
+When registering such models in the auto policy, use the following format:
+```python
+"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
+    file_name="<policy_filename>", class_name="<policy_class_name>"
+)
+```
+
+For the chatglm2 model, this becomes:
+```python
+"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
+)
+```
+
+When using such models, `AutoModel` is supported as usual, and the policy is loaded automatically by the auto policy.
+
 ### Write Your Unit Testing

 This section serves as the guideline for testing the `shardformer` module.
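
As an aside to the string-key discussion in the added README text above, here is a minimal sketch of how a policy mapping can match sub-modules by class-name string when the class cannot be imported ahead of time. The dictionary values and the `lookup` helper are illustrative placeholders, not the real shardformer policy objects.

```python
import torch.nn as nn

# Keys may be classes (importable as usual) or class-name strings
# (for modules that only exist after trust_remote_code loading, e.g. "GLMBlock").
policy = {
    nn.LayerNorm: "replace with fused layernorm",
    "GLMBlock": "shard the attention / MLP of the ChatGLM block",
}


def lookup(submodule: nn.Module):
    cls = submodule.__class__
    # Prefer an exact class match, then fall back to the class-name string.
    return policy.get(cls, policy.get(cls.__name__))


print(lookup(nn.LayerNorm(8)))  # replace with fused layernorm
```
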
@@ -424,13 +461,13 @@ We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate
 We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. 'N_CTX' refers to the sequence length.

 In the case of using 2 GPUs, the training times are as follows.
-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 11.2ms | 17.2ms |
-| 512 | 9.8ms | 19.5ms |
-| 1024 | 19.6ms | 18.9ms |
-| 2048 | 46.6ms | 30.8ms |
-| 4096 | 160.5ms | 90.4ms |
+| N_CTX | org_model | shard_model |
+|:-----:|:---------:|:-----------:|
+| 256 | 11.2ms | 17.2ms |
+| 512 | 9.8ms | 19.5ms |
+| 1024 | 19.6ms | 18.9ms |
+| 2048 | 46.6ms | 30.8ms |
+| 4096 | 160.5ms | 90.4ms |


 <p align="center">
@@ -440,13 +477,13 @@ In the case of using 2 GPUs, the training times are as follows.

 In the case of using 4 GPUs, the training times are as follows.

-| N_CTX | org_model | shard_model |
-| :------: | :-----: | :-----: |
-| 256 | 10.0ms | 21.1ms |
-| 512 | 11.5ms | 20.2ms |
-| 1024 | 22.1ms | 20.6ms |
-| 2048 | 46.9ms | 24.8ms |
-| 4096 | 160.4ms | 68.0ms |
+| N_CTX | org_model | shard_model |
+|:-----:|:---------:|:-----------:|
+| 256 | 10.0ms | 21.1ms |
+| 512 | 11.5ms | 20.2ms |
+| 1024 | 22.1ms | 20.6ms |
+| 2048 | 46.9ms | 24.8ms |
+| 4096 | 160.4ms | 68.0ms |


@@ -475,10 +512,10 @@ warmup_fraction = 0.03


 | accuracy | f1 | loss | GPU number | model sharded |
-| :------: | :-----: | :-----: | :--------: | :---------: |
-| 0.82971 | 0.87713 | 0.23194 | 4 | True |
-| 0.83797 | 0.88006 | 0.22683 | 2 | True |
-| 0.84521 | 0.88700 | 0.21822 | 1 | False |
+|:--------:|:-------:|:-------:|:----------:|:-------------:|
+| 0.82971 | 0.87713 | 0.23194 | 4 | True |
+| 0.83797 | 0.88006 | 0.22683 | 2 | True |
+| 0.84521 | 0.88700 | 0.21822 | 1 | False |


 Overall, the results demonstrate that using shardformers during model training does not affect the convergence.

colossalai/shardformer/layer/normalization.py

Lines changed: 8 additions & 11 deletions
@@ -281,19 +281,16 @@ def from_native_module(module: nn.Module, sp_partial_derived: bool = False, *arg
         )

         LazyInitContext.materialize(module)
-        # to check if it is huggingface LlamaRMSNorm or MistralRMSNorm
-        if module.__class__.__name__ in ["LlamaRMSNorm", "MistralRMSNorm"]:
-            normalized_shape = module.weight.shape[0]
-            eps = module.variance_epsilon
-            elementwise_affine = True
-        else:
-            # get the attributes of the module
-            normalized_shape = module.normalized_shape
-            eps = module.eps
-            elementwise_affine = module.elementwise_affine
+
+        # try to get normalized_shape, eps, elementwise_affine from the module
+        normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
+        eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
+        elementwise_affine = getattr(module, "elementwise_affine", True)

         rmsnorm = FusedRMSNormWithHook(
-            normalized_shape=normalized_shape, eps=eps, elementwise_affine=elementwise_affine
+            normalized_shape=normalized_shape,
+            eps=eps,
+            elementwise_affine=elementwise_affine,
         )

         rmsnorm.weight = module.weight
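
To illustrate the unified attribute extraction above: HF RMSNorm modules such as `LlamaRMSNorm` expose only `weight` and `variance_epsilon`, while `nn.LayerNorm` exposes `normalized_shape`, `eps` and `elementwise_affine`, so the `getattr`/`hasattr` fallbacks cover both. A small standalone check (assuming `torch` and `transformers` are installed; `extract_norm_args` is just a local helper mirroring the patched logic):

```python
import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaRMSNorm


def extract_norm_args(module: nn.Module):
    # Same fallback logic as the patched from_native_module above.
    normalized_shape = getattr(module, "normalized_shape", module.weight.shape[0])
    eps = module.variance_epsilon if hasattr(module, "variance_epsilon") else module.eps
    elementwise_affine = getattr(module, "elementwise_affine", True)
    return normalized_shape, eps, elementwise_affine


print(extract_norm_args(LlamaRMSNorm(64, eps=1e-6)))  # (64, 1e-06, True)
print(extract_norm_args(nn.LayerNorm(64, eps=1e-5)))  # ((64,), 1e-05, True)
```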

colossalai/shardformer/modeling/chatglm2.py

Lines changed: 15 additions & 6 deletions
@@ -12,7 +12,6 @@
 from colossalai.shardformer import ShardConfig
 from colossalai.shardformer.layer import AttnMaskType, ColoAttention
 from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
-from colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm import ChatGLMForConditionalGeneration, ChatGLMModel


 def get_flash_core_attention_forward():
@@ -31,7 +30,12 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
             device=query_layer.device,
         )
         temp_mask = (
-            torch.ones(query_layer.shape[2], key_layer.shape[2], dtype=torch.bool, device=query_layer.device)
+            torch.ones(
+                query_layer.shape[2],
+                key_layer.shape[2],
+                dtype=torch.bool,
+                device=query_layer.device,
+            )
             .tril(diagonal=0)
             .expand(query_layer.shape[0], 1, -1, -1)
         )
@@ -49,6 +53,7 @@ def forward(self: CoreAttention, query_layer, key_layer, value_layer, attention_
             attention_mask=attn_bias,
             attention_mask_type=attention_mask_type,
             dropout_p=dropout_p,
+            scale=1.0 / self.norm_factor,
         )
         context_layer = context_layer.permute(2, 0, 1, 3)
         new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_partition,)
@@ -115,7 +120,7 @@ class ChatGLMPipelineForwards:

     @staticmethod
     def chatglm_model_forward(
-        self: ChatGLMModel,
+        self: "ChatGLMModel",
         input_ids,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
@@ -194,7 +199,9 @@ def chatglm_model_forward(
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = split_forward_gather_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         for idx in range(start_idx, end_idx):
             layer = self.encoder._get_layer(idx)
@@ -224,7 +231,9 @@ def chatglm_model_forward(
         if shard_config and shard_config.enable_sequence_parallelism:
             if shard_config.sequence_parallelism_mode == "split_gather":
                 hidden_states = gather_forward_split_backward(
-                    hidden_states, dim=0, process_group=shard_config.tensor_parallel_process_group
+                    hidden_states,
+                    dim=0,
+                    process_group=shard_config.tensor_parallel_process_group,
                 )
         if output_hidden_states:
             all_hidden_states = all_hidden_states + (hidden_states,)
@@ -254,7 +263,7 @@ def chatglm_model_forward(

     @staticmethod
     def chatglm_for_conditional_generation_forward(
-        self: ChatGLMForConditionalGeneration,
+        self: "ChatGLMForConditionalGeneration",
         input_ids: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
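
To make the masking and scaling change in `get_flash_core_attention_forward` concrete, here is a hedged, pure-PyTorch sketch of the same idea (not the `ColoAttention` API). It assumes the `[batch, heads, seq, head_dim]` layout implied by the mask construction in the diff and scales attention scores by `1 / norm_factor` instead of the default `1 / sqrt(head_dim)`; `causal_attention` is an illustrative helper.

```python
import math

import torch


def causal_attention(q, k, v, norm_factor):
    # q, k, v: [batch, num_heads, seq_len, head_dim]
    scores = torch.matmul(q, k.transpose(-1, -2)) / norm_factor
    # Lower-triangular causal mask, analogous to
    # torch.ones(q_len, k_len).tril(diagonal=0).expand(batch, 1, -1, -1) in the diff.
    mask = torch.ones(q.shape[2], k.shape[2], dtype=torch.bool, device=q.device).tril(diagonal=0)
    scores = scores.masked_fill(~mask, float("-inf"))
    return torch.matmul(torch.softmax(scores, dim=-1), v)


q = k = v = torch.randn(2, 4, 8, 16)
print(causal_attention(q, k, v, norm_factor=math.sqrt(16)).shape)  # torch.Size([2, 4, 8, 16])
```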

colossalai/shardformer/policies/auto_policy.py

Lines changed: 10 additions & 3 deletions
@@ -151,10 +151,10 @@ class PolicyLocation:
         file_name="blip2", class_name="Blip2ForConditionalGenerationPolicy"
     ),
     # ChatGLM
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMModel": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMModel": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMModelPolicy"
     ),
-    "colossalai.shardformer.modeling.chatglm2_6b.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
+    "transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
         file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
     ),
     # Falcon
@@ -202,6 +202,13 @@ def _fullname(obj):
     module = klass.__module__
     if module == "builtins":
         return klass.__qualname__  # avoid outputs like 'builtins.str'
+    # patch custom models which are not in transformers
+    # it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
+    # or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
+    if module.startswith("transformers_modules"):
+        split_module = module.split(".")
+        if len(split_module) >= 2:
+            module = f"{split_module[0]}.{split_module[-1]}"
     return module + "." + klass.__qualname__


@@ -220,7 +227,7 @@ def get_autopolicy(model: nn.Module) -> Policy:

     if policy_location is None:
         raise NotImplementedError(
-            f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
+            f"Auto policy for {model.__class__.__qualname__} ({full_name}) is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
         )
     else:
         policy = import_policy(policy_location)
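
The `_fullname` patch above boils down to collapsing the variable middle of a `transformers_modules` module path. Here is a standalone sketch using only standard string handling and the two example paths from the diff comments (`normalize_transformers_modules` is an illustrative name, not the library function):

```python
def normalize_transformers_modules(module: str, qualname: str) -> str:
    # Keep only the first and last components of a transformers_modules path,
    # dropping the repo owner / name / revision inserted by the hub loader.
    if module.startswith("transformers_modules"):
        parts = module.split(".")
        if len(parts) >= 2:
            module = f"{parts[0]}.{parts[-1]}"
    return f"{module}.{qualname}"


# Hub download (owner, repo name and revision embedded in the module path):
print(normalize_transformers_modules(
    "transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm",
    "ChatGLMForConditionalGeneration",
))  # transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration

# Local directory:
print(normalize_transformers_modules("transformers_modules.chatglm.modeling_chatglm", "ChatGLMModel"))
# transformers_modules.modeling_chatglm.ChatGLMModel
```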
