
Commit 7780497

[GPT-3] Support generation code for GPT-3 in static graph. (#2188)
* Support GPT-3 generation in static graph
* Support batch generation and fix code style
1 parent 1fdf308 commit 7780497

10 files changed, +690 −33 lines changed

examples/language_model/gpt

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-../../model_zoo/gpt
+../../model_zoo/gpt/

examples/language_model/gpt-3/README.md

Lines changed: 11 additions & 0 deletions
@@ -144,5 +144,16 @@ python -u -m paddle.distributed.fleet.launch \
 
 Besides the hybrid-parallelism strategies above, PaddlePaddle also supports recomputation, offload, and mixed precision to reduce memory usage and speed up training. For more details, see: [飞桨分布式训练又推新品，4D混合并行可训千亿级AI模型](https://baijiahao.baidu.com/s?id=1697085717806202673)
 
+### Deploying Ultra-Large Models with PaddlePaddle
+
+PaddlePaddle tools for ultra-large model deployment:
+
+- Paddle Fleet: PaddlePaddle's adaptive parallel training technology, which also applies to ultra-large model deployment and adaptively partitions the model for the target inference hardware
+- Paddle Inference: supports model-parallel, pipeline-parallel, and hybrid-parallel strategies, heavily optimized for leading performance
+- Paddle Serving: supports service-based deployment with automatic batching, fault-tolerant scheduling, service monitoring, and load balancing
+- Paddle Slim: supports quantization and sparse compression of ultra-large models
+
+See the [GPT-3 ultra-large model deployment tutorial](deploy) for a concrete deployment example
+
 ### References
 - [Language Models are Few-Shot Learners](https://arxiv.org/pdf/2005.14165.pdf)
Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+## Ultra-Large Model Deployment
+
+TBD
+
+### Model Export
+
+### Automatic Partitioning
+
+### Inference Deployment
+
+### Benchmark

examples/language_model/gpt-3/static/args.py

Lines changed: 40 additions & 1 deletion
@@ -160,7 +160,6 @@ def parse_args(MODEL_CLASSES):
         type=int,
         default=10,
         help="Evaluate the model use X steps data.")
-
     # Config for 4D Parallelism
     parser.add_argument(
         "--use_sharding",
@@ -258,6 +257,46 @@ def parse_args(MODEL_CLASSES):
         default=None,
         help='The option of profiler, which should be in format \"key1=value1;key2=value2;key3=value3\".'
     )
+    parser.add_argument(
+        "--max_dec_len",
+        type=int,
+        default=20,
+        help="The maximum length of the decoded sequence.")
+    parser.add_argument(
+        "--decoding_strategy",
+        type=str,
+        default="topk_sampling",
+        choices=["topk_sampling", "topp_sampling", "sampling"],
+        help="The decoding strategy; beam_search is not supported yet.")
+    parser.add_argument(
+        "--temperature",
+        type=float,
+        default=1.,
+        help="The temperature in each generation step.")
+    # top-k sampling
+    parser.add_argument(
+        "--topk",
+        type=int,
+        default=10,
+        help="The hyper-parameter in top-k sampling.")
+    # top-p sampling
+    parser.add_argument(
+        "--topp",
+        type=float,
+        default=0.9,
+        help="The hyper-parameter in top-p sampling.")
+    # beam search
+    parser.add_argument(
+        "--beam_size",
+        type=int,
+        default=1,
+        help="The hyper-parameter in beam search.")
+    parser.add_argument(
+        "--save_inference_model_then_exist",
+        type=bool,
+        default=False,
+        help="Whether to save the inference model.")
+
     args = parser.parse_args()
     args.test_iters = args.eval_iters * 10
 
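The new flags above configure how the next token is drawn at each generation step. As a rough illustration of what the three `--decoding_strategy` choices and the `--temperature`/`--topk`/`--topp` knobs mean, here is a minimal NumPy sketch; the function name and shapes are invented for the example, and the commit itself implements this with Paddle static-graph ops rather than NumPy:

```python
# Illustrative sketch only: one decoding step for a single sequence.
import numpy as np


def sample_next_token(logits, decoding_strategy="topk_sampling",
                      temperature=1.0, topk=10, topp=0.9):
    """Pick the next token id from a 1-D logits vector."""
    logits = np.asarray(logits, dtype="float64") / temperature
    probs = np.exp(logits - logits.max())
    probs /= probs.sum()

    if decoding_strategy == "topk_sampling":
        # Keep only the k most probable tokens, renormalize, then sample.
        keep = np.argsort(-probs)[:topk]
        p = probs[keep] / probs[keep].sum()
        return int(np.random.choice(keep, p=p))
    elif decoding_strategy == "topp_sampling":
        # Keep the smallest prefix of tokens whose cumulative probability >= topp.
        order = np.argsort(-probs)
        cutoff = int(np.searchsorted(np.cumsum(probs[order]), topp)) + 1
        keep = order[:cutoff]
        p = probs[keep] / probs[keep].sum()
        return int(np.random.choice(keep, p=p))
    else:  # "sampling": draw from the full softmax distribution
        return int(np.random.choice(len(probs), p=probs))
```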
examples/language_model/gpt-3/static/dataset.py

Lines changed: 29 additions & 18 deletions
@@ -148,7 +148,7 @@ def _num_tokens(documents, lens):
 
 
 def _num_epochs(tokens_per_epoch, seq_length, num_samples):
-    """Based on number of samples and sequence lenght, calculate how many
+    """Based on number of samples and sequence length, calculate how many
     epochs will be needed."""
     num_epochs = 0
     total_tokens = 0
@@ -256,18 +256,17 @@ def get_train_valid_test_split_(splits_string, size):
     return splits_index
 
 
-def create_pretrained_dataset(
-        args,
-        input_path,
-        local_rank,
-        data_world_rank,
-        data_world_size,
-        eos_id,
-        worker_init=None,
-        max_seq_len=1024,
-        places=None,
-        data_holders=None,
-        pipeline_mode=False, ):
+def create_pretrained_dataset(args,
+                              input_path,
+                              local_rank,
+                              data_world_rank,
+                              data_world_size,
+                              eos_id,
+                              worker_init=None,
+                              max_seq_len=1024,
+                              places=None,
+                              data_holders=None,
+                              pipeline_mode=False):
 
     if local_rank == 0:
         start_time = time.time()
@@ -339,7 +338,8 @@ def build_dataset(index, name, num_samples):
             sample_lens=sample_lens,
             eos_id=eos_id,
             seed=args.seed,
-            use_pure_fp16=args.use_amp and args.amp_level == "O2")
+            use_pure_fp16=args.use_amp and args.amp_level == "O2",
+            data_holders=data_holders)
         batch_sampler = DistributedBatchSampler(
             dataset,
             batch_size=args.micro_batch_size,
@@ -361,14 +361,16 @@ def data_gen():
         data_loader.set_sample_generator(
             data_gen, batch_size=args.micro_batch_size, places=places)
     else:
+        stacks = (Stack(), ) * len(data_holders)
+        collate_fn = Tuple(*stacks)
         data_loader = DataLoader(
             dataset=dataset,
             places=places,
             feed_list=data_holders,
             batch_sampler=batch_sampler,
             num_workers=1,
             worker_init_fn=worker_init,
-            collate_fn=Tuple(Stack(), Stack(), Stack(), Stack()),
+            collate_fn=collate_fn,
             return_list=False)
     return data_loader
 

@@ -401,7 +403,8 @@ def __init__(self,
                  name="gpt",
                  max_seq_len=1024,
                  seed=1234,
-                 use_pure_fp16=False):
+                 use_pure_fp16=False,
+                 data_holders=None):
         self.file_prefix = file_prefix
         self.max_seq_len = max_seq_len
         self.name = name
@@ -410,6 +413,7 @@ def __init__(self,
         self.sample_lens = sample_lens
         self.micro_batch_size = micro_batch_size
         self.use_pure_fp16 = use_pure_fp16
+        self.data_holders = data_holders
 
         if documents is None:
             document_ids = np.arange(0, self.sample_lens.shape[0])
@@ -435,10 +439,17 @@ def _construct_sample(self, tokens):
         else:
             loss_mask = np.ones(seq_length, dtype="float32")
             loss_mask[np.where(np.array(tokens) == self.eos_id)] = 0.0
-            position_ids = np.arange(0, seq_length, dtype="int64")
 
+        position_ids = np.arange(0, seq_length, dtype="int64")
         labels = np.array(labels, dtype="int64")
-        return [tokens, loss_mask, position_ids, labels]
+        if len(self.data_holders) == 4:
+            return [tokens, loss_mask, position_ids, labels]
+        elif len(self.data_holders) == 3:
+            return [tokens, loss_mask, position_ids]
+        else:
+            assert len(self.data_holders) == 1, \
+                "length of data_holders should be 4, 3 or 1"
+            return [tokens]
 
     def _get_single_sample_from_idx(self, doc_index_f, doc_index_l, offset_f,
                                     offset_l):

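With `data_holders` now threaded through the dataset, the collate function is built as one `Stack()` per holder, so training can feed four arrays (tokens, loss mask, position ids, labels) while generation feeds only `tokens`. A minimal self-contained sketch of that pattern follows, using plain NumPy stand-ins and made-up helper names rather than paddlenlp's actual `Stack`/`Tuple` classes:

```python
# Illustrative sketch of the "one Stack per data holder" collate pattern.
import numpy as np


def stack(field_list):
    """Stack one field across the batch -> array of shape [batch, ...]."""
    return np.stack(field_list, axis=0)


def make_collate_fn(num_holders):
    """Mimics Tuple((Stack(),) * len(data_holders)): one batched array per holder."""
    def collate(samples):
        # samples: list of per-example field lists,
        # e.g. [tokens, loss_mask, position_ids, labels] or just [tokens]
        assert all(len(s) == num_holders for s in samples)
        return [stack([s[i] for s in samples]) for i in range(num_holders)]
    return collate


# Generation feeds a single holder (tokens); training would use num_holders=4.
batch = [[np.arange(8)], [np.arange(8)]]
out = make_collate_fn(1)(batch)
print(out[0].shape)  # (2, 8)
```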