15
15
16
16
import glob
17
17
import os
18
+ import sys
18
19
import tempfile
19
- import unittest
20
+ from pathlib import Path
20
21
21
22
import paddle
22
23
from paddle .distributed import fleet
23
24
25
+ sys .path .append (str (Path (__file__ ).parent .parent .parent ))
26
+ from tests .parallel_launch import TestMultipleGpus
27
+ from tests .testing_utils import require_paddle_at_least_2_gpu
28
+
24
29
tp_size = paddle .distributed .get_world_size ()
25
30
tp_rank = 0
26
31
if tp_size > 1 :
@@ -42,6 +47,7 @@ def prepare_config(config):
42
47
config .num_layers = 2
43
48
config .num_hidden_layers = 2
44
49
config .num_attention_heads = 16
50
+ config .num_key_value_heads = 16
45
51
config .intermediate_size = config .hidden_size * 3
46
52
config .tensor_parallel_degree = tp_size
47
53
config .tensor_parallel_rank = tp_rank
@@ -118,6 +124,15 @@ def _test_bloom():
118
124
common_test_merge (model , BloomForCausalLM )
119
125
120
126
127
+ def _test_qwen2 ():
128
+ from paddlenlp .transformers import Qwen2Config , Qwen2ForCausalLM
129
+
130
+ config = Qwen2Config ()
131
+ config = prepare_config (config )
132
+ model = Qwen2ForCausalLM .from_config (config )
133
+ common_test_merge (model , Qwen2ForCausalLM )
134
+
135
+
121
136
def _test_gemma ():
122
137
from paddlenlp .transformers import GemmaConfig , GemmaForCausalLM
123
138
@@ -127,15 +142,15 @@ def _test_gemma():
127
142
common_test_merge (model , GemmaForCausalLM )
128
143
129
144
130
- # _test_llama()
131
- # _test_chatglm()
132
- # _test_bloom()
145
+ @require_paddle_at_least_2_gpu
146
+ class TestTensorParallel (TestMultipleGpus ):
147
+ def test_model_load_merge (self ):
148
+ self .run_2gpu (__file__ )
133
149
134
150
135
- class TestTensorParallel (unittest .TestCase ):
136
- @unittest .skipIf (tp_size < 2 , "Need muti-gpu to run this test!" )
137
- def test_model_load_merge (self ):
138
- _test_llama ()
139
- _test_chatglm ()
140
- _test_bloom ()
141
- _test_gemma ()
151
+ if __name__ == "__main__" :
152
+ _test_llama ()
153
+ _test_chatglm ()
154
+ _test_bloom ()
155
+ _test_gemma ()
156
+ _test_qwen2 ()
0 commit comments