
Commit e3aee9b

wangxicoding and ZHUI authored
[GPT-3] Support tensor model parallel in static graph. (#2245)
Co-authored-by: Zhong Hui <[email protected]>
1 parent 83f37a1 commit e3aee9b

File tree

3 files changed: +23, -9 lines changed


examples/language_model/gpt-3/static/modeling.py

Lines changed: 6 additions & 3 deletions

@@ -910,17 +910,20 @@ def _init_generation_caches(self, src_ids):
 
     def parallel_matmul(self, lm_output, logit_weights, parallel_output, topo):
         if topo is not None and topo.mp_info.size > 1:
+            hybrid_groups = fleet.get_hybrid_communicate_group()
+            model_parallel_group = hybrid_groups.get_model_parallel_group()
+
             input_parallel = paddle.distributed.collective._c_identity(
-                lm_output, group=None)
+                lm_output, group=model_parallel_group)
 
             logits = paddle.matmul(
                 input_parallel, logit_weights, transpose_y=True)
 
             if parallel_output:
                 return logits
 
-            # TODO(qinqing): collective._c_concat is not support in static graph now
-            return paddle.distributed.collective._c_concat(logits, group=None)
+            return paddle.distributed.collective._c_concat(
+                logits, group=model_parallel_group)
         else:
             logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
             return logits
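
As a reading aid, the gather path that this change wires up can be sketched as a standalone helper. This is a minimal sketch rather than code from the patch: it assumes fleet has already been initialized with tensor parallelism (as run_generation.py below now does) and that logit_weight_shard is the current rank's slice of the output-embedding weight; the names gather_parallel_logits and logit_weight_shard are illustrative only.

import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed import collective


def gather_parallel_logits(lm_output, logit_weight_shard, parallel_output):
    # Model-parallel group registered by fleet.init, as in the patch above.
    hcg = fleet.get_hybrid_communicate_group()
    mp_group = hcg.get_model_parallel_group()

    # _c_identity passes the value through in the forward pass; it is there so
    # the backward pass inserts the matching all-reduce within the group.
    input_parallel = collective._c_identity(lm_output, group=mp_group)

    # Each rank multiplies against its own shard of the logit weights.
    logits = paddle.matmul(input_parallel, logit_weight_shard, transpose_y=True)

    if parallel_output:
        # Hand back the per-rank vocabulary shard unchanged.
        return logits

    # Concatenate the per-rank shards across the model-parallel group so
    # every rank sees the full vocabulary dimension.
    return collective._c_concat(logits, group=mp_group)

In the patch itself this logic lives inside parallel_matmul and is only taken when topo.mp_info.size > 1; with a single rank, the plain matmul branch is used instead.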

examples/language_model/gpt-3/static/run_gen.sh

Lines changed: 0 additions & 4 deletions

@@ -19,10 +19,6 @@ python -u -m paddle.distributed.fleet.launch \
     --max_seq_len 1024 \
     --micro_batch_size 2 \
     --global_batch_size 2 \
-    --sharding_degree 1 \
-    --mp_degree 1 \
-    --dp_degree 1 \
-    --pp_degree 1 \
     --max_dec_len 20 \
     --decoding_strategy 'topk_sampling' \
     --topp 0.9 \

examples/language_model/gpt-3/static/run_generation.py

Lines changed: 17 additions & 2 deletions

@@ -114,11 +114,26 @@ def do_generation(args):
     # Initialize the paddle and paddle fleet execute environment
     paddle.enable_static()
 
+    assert args.dp_degree == 1, "Data parallel is not supported in inference"
+    assert args.sharding_degree == 1, "Sharding parallel is temporarily not supported in inference"
+    assert args.pp_degree == 1, "Pipeline parallel will be supported later"
+
+    if args.mp_degree == 1:
+        args.mp_degree = paddle.distributed.get_world_size()
+    else:
+        assert args.mp_degree == paddle.distributed.get_world_size(), \
+            "If mp_degree is specified, the size must be the same as world_size"
+
     strategy = fleet.DistributedStrategy()
-    strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
+    strategy.tensor_parallel = True
+    strategy.tensor_parallel_configs = {
+        "tensor_parallel_degree": args.mp_degree
+    }
+
     fleet.init(is_collective=True, strategy=strategy)
 
-    group = paddle.distributed.init_parallel_env()
+    # temp use dynamic init, use HybridParallelInferenceHelper in future?
+    paddle.distributed.init_parallel_env()
 
     # Create the random seed for the worker
     random.seed(args.seed)
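
The two changes can be exercised together with a minimal driver. The sketch below is illustrative only: it assumes the script is launched across the target GPUs with python -u -m paddle.distributed.fleet.launch (as run_gen.sh does) and mirrors the new behaviour of defaulting mp_degree to the launched world size.

import paddle
import paddle.distributed as dist
import paddle.distributed.fleet as fleet

paddle.enable_static()

# As in the patched run_generation.py: without an explicit --mp_degree,
# every launched rank joins the tensor (model) parallel group.
mp_degree = dist.get_world_size()

strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": mp_degree}

fleet.init(is_collective=True, strategy=strategy)
dist.init_parallel_env()

# The same group that modeling.py's parallel_matmul now passes to
# _c_identity and _c_concat.
mp_group = fleet.get_hybrid_communicate_group().get_model_parallel_group()
print("rank", dist.get_rank(), "joined a model-parallel group of degree", mp_degree)

Because data, pipeline, and sharding parallelism are asserted to be disabled for inference, only the tensor-parallel degree needs to be configured, which is why strategy.tensor_parallel replaces the earlier hard-coded hybrid_configs.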
