We will follow this roadmap to develop Shardformer:

- [x] Unit Testing
- [ ] Policy Implementation

| model       | tensor parallel | pipeline parallel | lazy initialization | xformer | flash attn2 | jit fused operator | fused layernorm | sequence parallel | overlap |
| :---------: | :-------------: | :---------------: | :-----------------: | :-----: | :---------: | :----------------: | :-------------: | :---------------: | :-----: |
| bert        | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| t5          | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| llama V1/V2 | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| gpt2        | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| opt         | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| bloom       | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| chatglm2    | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] | [√] |
| vit         | [√] | [√] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| whisper     | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| sam         | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| blip2       | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |
| falcon      | [√] | [√] | [√] | [√] | [√] | [ ] | [√] | [ ] | [ ] |
| roberta     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| albert      | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| ernie       | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-neo     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| gpt-j       | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| beit        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| swin V2     | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| qwen        | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] | [ ] |
| mistral     | [√] | [ ] | [ ] | [√] | [√] | [√] | [√] | [ ] | [ ] |

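Each optimization column in the table corresponds to a switch on the sharding configuration. As a rough illustration, a configuration enabling them might look like the sketch below; the `ShardConfig` flag names and the `optimize` entry point are assumptions based on recent ColossalAI releases and the API described in the next section, so verify them against your installed version.

```python
# Hedged sketch: mapping the table's feature columns onto ShardConfig switches.
# Flag names and the ShardFormer.optimize entry point are assumptions; verify
# them against colossalai.shardformer in your installed version.
import torch.distributed as dist
from colossalai.shardformer import ShardConfig, ShardFormer

# Assumes the default process group has already been initialized
# (e.g. via colossalai's launch utilities).
shard_config = ShardConfig(
    tensor_parallel_process_group=dist.group.WORLD,  # "tensor parallel" column
    enable_fused_normalization=True,                 # "fused layernorm" column
    enable_flash_attention=True,                     # "flash attn2" column
    enable_jit_fused=True,                           # "jit fused operator" column
    enable_sequence_parallelism=True,                # "sequence parallel" column
    enable_sequence_overlap=True,                    # "overlap" column
)
shard_former = ShardFormer(shard_config=shard_config)
# sharded_model, shared_params = shard_former.optimize(model)  # assumed entry point
```
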
## 💡 API Design
```python
_POLICY_LIST = {
    # ... (policy registrations)
}
```

#### How to support models that are on the Hugging Face model hub but not in the `transformers` library

There are two cases:

1. The modeling file is in the `transformers` library but the checkpoint is only hosted on the model hub. For example, `01-ai/Yi-34B` shares the LLaMA model structure, so supporting llama as usual also covers Yi-34B through the llama policy; no separate policy is needed for Yi-34B.
2. The modeling file itself is not in the `transformers` library, such as `THUDM/chatglm2-6b`.

Taking `THUDM/chatglm2-6b` as an example, we illustrate how to support such a model in `shardformer`.

Unlike llama, whose modeling code lives in the `transformers` library, the chatglm2 model class cannot be imported directly. Therefore, the policy key should be the class name as a string rather than the class itself.

E.g. for llama:
```python
policy[LlamaDecoderLayer] = ModulePolicyDescription(...)
```

for chatglm2:
```python
policy["GLMBlock"] = ModulePolicyDescription(...)
```

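For illustration, here is a rough sketch of what the string-keyed `GLMBlock` entry could expand to. The `ModulePolicyDescription`/`SubModuleReplacementDescription` fields follow the shardformer policy API, but the chosen submodule suffixes and parallel layers below are assumptions for chatglm2 and should be checked against the real chatglm2 policy.

```python
# Hedged, illustrative expansion of a string-keyed policy entry.
# The submodule suffixes and the chosen parallel layers are assumptions for
# chatglm2; consult the actual chatglm2 policy for the authoritative mapping.
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
from colossalai.shardformer.policies.base_policy import (
    ModulePolicyDescription,
    SubModuleReplacementDescription,
)

policy = {}
policy["GLMBlock"] = ModulePolicyDescription(
    sub_module_replacement=[
        SubModuleReplacementDescription(
            suffix="self_attention.query_key_value",  # assumed submodule name
            target_module=Linear1D_Col,               # column-parallel split
        ),
        SubModuleReplacementDescription(
            suffix="self_attention.dense",            # assumed submodule name
            target_module=Linear1D_Row,               # row-parallel merge
        ),
    ],
)
```
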
Then, when registering such models in the autopolicy, we should follow the format below:
```python
"transformers_modules.<modeling_filename>.<class_name>": PolicyLocation(
    file_name="<policy_filename>", class_name="<policy_class_name>"
)
```

For the chatglm2 model, it should be:
```python
"transformers_modules.modeling_chatglm.ChatGLMForConditionalGeneration": PolicyLocation(
    file_name="chatglm2", class_name="ChatGLMForConditionalGenerationPolicy"
)
```

When using such models, `AutoModel` is supported as usual, and the corresponding policy is loaded automatically by the autopolicy.

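As a hedged end-to-end illustration, the snippet below loads the hub model with `trust_remote_code=True` and prints the class's module path, which is what the `transformers_modules.*` policy key has to match. The exact path layout (it may embed the repo name or a revision hash) and the commented-out `ShardFormer` calls are assumptions; verify them against your installed `transformers` and ColossalAI versions.

```python
# Hedged usage sketch; the module-path layout and the commented-out
# ShardFormer API are assumptions to be verified locally.
from transformers import AutoModel

model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True)

# The dynamically imported class lives under a transformers_modules.* path;
# this is the part the autopolicy key has to match.
print(model.__class__.__module__, model.__class__.__name__)

# Sharding then works as for any built-in model:
# from colossalai.shardformer import ShardConfig, ShardFormer
# shard_former = ShardFormer(shard_config=ShardConfig(...))
# model, shared_params = shard_former.optimize(model)
```
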
### Write Your Unit Testing

This section serves as the guideline for testing the `shardformer` module.
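Purely as an illustration of the kind of check such a test performs, a minimal sketch is shown below: the sharded model must produce outputs numerically close to the original model. The real tests use colossalai's distributed test harness and model builders, and the `last_hidden_state` field assumes a standard Hugging Face base-model output; both are assumptions here.

```python
# Illustrative only: the core equivalence check behind a shardformer unit test.
# The real test suite uses colossalai's distributed testing utilities instead.
import copy

import torch


def check_forward_equivalence(org_model, shard_fn, inputs, atol=1e-5, rtol=1e-3):
    """Compare original vs. sharded forward outputs on the same inputs.

    `shard_fn` is any function that returns a sharded copy of the model,
    e.g. a wrapper around ShardFormer.optimize (assumed API).
    """
    sharded_model = shard_fn(copy.deepcopy(org_model))
    org_out = org_model(**inputs)
    shard_out = sharded_model(**inputs)
    torch.testing.assert_close(
        org_out.last_hidden_state,  # assumes a HF base-model output object
        shard_out.last_hidden_state,
        atol=atol,
        rtol=rtol,
    )
```
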
We conducted [benchmark tests](./examples/performance_benchmark.py) to evaluate the performance of Shardformer. We set the batch size to 4, the number of attention heads to 8, and the head dimension to 64. `N_CTX` refers to the sequence length.

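The numbers below come from the linked benchmark script. Purely as an illustration of how the per-step training time can be measured (this is not the contents of `examples/performance_benchmark.py`), a minimal CUDA-event timing helper might look like this:

```python
# Illustrative timing helper only -- not the actual performance_benchmark.py.
# `step` is any callable that runs one forward+backward pass of the model.
import torch


def time_step_ms(step, warmup: int = 5, iters: int = 20) -> float:
    """Return the average wall time of step() in milliseconds."""
    for _ in range(warmup):  # warm up kernels and the CUDA allocator
        step()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        step()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters
```

The multi-GPU runs would typically be launched with something like `torchrun --nproc_per_node <num_gpus> examples/performance_benchmark.py`; the exact launch command is an assumption.
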
In the case of using 2 GPUs, the training times are as follows.
| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
| 256   | 11.2ms    | 17.2ms      |
| 512   | 9.8ms     | 19.5ms      |
| 1024  | 19.6ms    | 18.9ms      |
| 2048  | 46.6ms    | 30.8ms      |
| 4096  | 160.5ms   | 90.4ms      |

In the case of using 4 GPUs, the training times are as follows.

| N_CTX | org_model | shard_model |
| :---: | :-------: | :---------: |
| 256   | 10.0ms    | 21.1ms      |
| 512   | 11.5ms    | 20.2ms      |
| 1024  | 22.1ms    | 20.6ms      |
| 2048  | 46.9ms    | 24.8ms      |
| 4096  | 160.4ms   | 68.0ms      |

| accuracy | f1      | loss    | GPU number | model sharded |
| :------: | :-----: | :-----: | :--------: | :-----------: |
| 0.82971  | 0.87713 | 0.23194 | 4          | True          |
| 0.83797  | 0.88006 | 0.22683 | 2          | True          |
| 0.84521  | 0.88700 | 0.21822 | 1          | False         |

Overall, the results demonstrate that using Shardformer during model training does not affect convergence.