Skip to content

Commit 326cf82

Browse files
authored
Enable tensor parallel tests. (#8757)
1 parent bafd7e5 commit 326cf82

File tree

1 file changed

+26
-11
lines changed

1 file changed

+26
-11
lines changed

tests/transformers/test_tensor_parallel.py

Lines changed: 26 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,17 @@
1515

1616
import glob
1717
import os
18+
import sys
1819
import tempfile
19-
import unittest
20+
from pathlib import Path
2021

2122
import paddle
2223
from paddle.distributed import fleet
2324

25+
sys.path.append(str(Path(__file__).parent.parent.parent))
26+
from tests.parallel_launch import TestMultipleGpus
27+
from tests.testing_utils import require_paddle_at_least_2_gpu
28+
2429
tp_size = paddle.distributed.get_world_size()
2530
tp_rank = 0
2631
if tp_size > 1:
@@ -42,6 +47,7 @@ def prepare_config(config):
4247
config.num_layers = 2
4348
config.num_hidden_layers = 2
4449
config.num_attention_heads = 16
50+
config.num_key_value_heads = 16
4551
config.intermediate_size = config.hidden_size * 3
4652
config.tensor_parallel_degree = tp_size
4753
config.tensor_parallel_rank = tp_rank
@@ -118,6 +124,15 @@ def _test_bloom():
118124
common_test_merge(model, BloomForCausalLM)
119125

120126

127+
def _test_qwen2():
    """Smoke-test tensor-parallel weight merge for a tiny Qwen2 causal LM.

    Builds a minimal config via the shared `prepare_config` helper, then
    runs the common load/merge round-trip check.
    """
    # Local import keeps optional model deps out of module import time.
    from paddlenlp.transformers import Qwen2Config, Qwen2ForCausalLM

    cfg = prepare_config(Qwen2Config())
    common_test_merge(Qwen2ForCausalLM.from_config(cfg), Qwen2ForCausalLM)
134+
135+
121136
def _test_gemma():
122137
from paddlenlp.transformers import GemmaConfig, GemmaForCausalLM
123138

@@ -127,15 +142,15 @@ def _test_gemma():
127142
common_test_merge(model, GemmaForCausalLM)
128143

129144

130-
# _test_llama()
131-
# _test_chatglm()
132-
# _test_bloom()
145+
# Skipped entirely unless at least 2 Paddle-visible GPUs are available.
@require_paddle_at_least_2_gpu
class TestTensorParallel(TestMultipleGpus):
    """Launches this file as a 2-GPU distributed job via the test harness."""

    def test_model_load_merge(self):
        # Re-executes this very file under paddle.distributed with 2 workers;
        # the actual model checks run in the workers' __main__ block below.
        self.run_2gpu(__file__)
133149

134150

135-
class TestTensorParallel(unittest.TestCase):
136-
@unittest.skipIf(tp_size < 2, "Need muti-gpu to run this test!")
137-
def test_model_load_merge(self):
138-
_test_llama()
139-
_test_chatglm()
140-
_test_bloom()
141-
_test_gemma()
151+
# Worker entry point: executed in each of the spawned GPU processes
# (via TestTensorParallel.run_2gpu), not when collected by pytest.
if __name__ == "__main__":
    _test_llama()
    _test_chatglm()
    _test_bloom()
    _test_gemma()
    _test_qwen2()

0 commit comments

Comments
 (0)