
Commit f560733

[LLM] support bloom dybatch (#7003)
* support_llama
* support_llama
* delete_loss

1 parent e49842c

File tree: 8 files changed, +700, -37 lines changed

csrc/generation/set_alibi_mask_value.cu (new file)

Lines changed: 136 additions & 0 deletions

```cpp
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "helper.h"

// One block per batch element, one thread per attention head: writes the
// ALiBi bias (target position * per-head slope) at each sequence's current
// position in the mask, and zeroes the sequence length of stopped sequences.
template <typename T>
__global__ void set_value_by_id(const int *seq_lens,
                                const bool *stop_flags,
                                const float *alibi_slopes,
                                const int64_t *tgt_pos,
                                T *output_data,
                                int *sequence_lengths,
                                int bs,
                                int length,
                                int num_head) {
    int bs_id = blockIdx.x;
    int hid = threadIdx.x;
    if (bs_id < bs) {
        T *output_data_now = output_data + bs_id * num_head * length + hid * length;
        float tgt_pos_now = static_cast<float>(tgt_pos[bs_id]);
        output_data_now[seq_lens[bs_id]] = static_cast<T>(tgt_pos_now * alibi_slopes[hid]);
        if (stop_flags[bs_id]) {
            sequence_lengths[bs_id] = 0;
        }
    }
}

template <paddle::DataType D>
std::vector<paddle::Tensor> set_mask_value(const paddle::Tensor& input_data,
                                           const paddle::Tensor& stop_flags,
                                           const paddle::Tensor& seq_lens,
                                           const paddle::Tensor& alibi_slopes,
                                           const paddle::Tensor& tgt_pos) {
    typedef PDTraits<D> traits_;
    typedef typename traits_::DataType DataType_;
    typedef typename traits_::data_t data_t;

    PD_CHECK(seq_lens.dtype() == paddle::DataType::INT32);
    PD_CHECK(stop_flags.dtype() == paddle::DataType::BOOL);
    auto cu_stream = input_data.stream();
    std::vector<int64_t> input_data_shape = input_data.shape();
    std::vector<int64_t> seq_lens_shape = seq_lens.shape();
    auto sequence_lengths = seq_lens.copy_to(seq_lens.place(), false);

    int input_bs = input_data_shape[0];
    int length = input_data_shape[3];
    int seq_bs = seq_lens_shape[0];
    int num_head = alibi_slopes.shape()[0];

    int grid_size = input_bs;
    int block_size = num_head;
    set_value_by_id<<<grid_size, block_size, 0, cu_stream>>>(
        seq_lens.data<int>(),
        stop_flags.data<bool>(),
        alibi_slopes.data<float>(),
        tgt_pos.data<int64_t>(),
        reinterpret_cast<DataType_*>(const_cast<data_t*>(input_data.data<data_t>())),
        sequence_lengths.data<int>(),
        seq_bs,
        length,
        num_head);
    return {sequence_lengths};
}

std::vector<paddle::Tensor> SetMaskValue(const paddle::Tensor& input_data,
                                         const paddle::Tensor& stop_flags,
                                         const paddle::Tensor& seq_lens,
                                         const paddle::Tensor& alibi_slopes,
                                         const paddle::Tensor& tgt_pos) {
    switch (input_data.type()) {
        case paddle::DataType::BFLOAT16: {
            return set_mask_value<paddle::DataType::BFLOAT16>(
                input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
        }
        case paddle::DataType::FLOAT16: {
            return set_mask_value<paddle::DataType::FLOAT16>(
                input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
        }
        case paddle::DataType::FLOAT32: {
            return set_mask_value<paddle::DataType::FLOAT32>(
                input_data, stop_flags, seq_lens, alibi_slopes, tgt_pos);
        }
        default: {
            PD_THROW(
                "NOT supported data type. "
                "Only float16, bfloat16 and float32 are supported. ");
            break;
        }
    }
}

std::vector<std::vector<int64_t>> SetMaskValueInferShape(const std::vector<int64_t>& input_data_shape,
                                                         const std::vector<int64_t>& stop_flags_shape,
                                                         const std::vector<int64_t>& seq_lens_shape,
                                                         const std::vector<int64_t>& alibi_slopes_shape,
                                                         const std::vector<int64_t>& tgt_pos) {
    return {seq_lens_shape};
}

std::vector<paddle::DataType> SetMaskValueInferDtype(const paddle::DataType& input_data_dtype,
                                                     const paddle::DataType& stop_flags_dtype,
                                                     const paddle::DataType& seq_lens_dtype,
                                                     const paddle::DataType& alibi_slopes_dtype,
                                                     const paddle::DataType& tgt_pos_dtype) {
    return {seq_lens_dtype};
}

PD_BUILD_OP(set_alibi_mask_value)
    .Inputs({"input_data", "stop_flags", "seq_lens", "alibi_slopes", "tgt_pos"})
    .Outputs({"sequence_lengths"})
    .SetKernelFn(PD_KERNEL(SetMaskValue))
    .SetInferShapeFn(PD_INFER_SHAPE(SetMaskValueInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(SetMaskValueInferDtype));
```
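Once the csrc extension is compiled, the op registered by `PD_BUILD_OP` becomes callable from Python. The following is a minimal usage sketch, not part of this diff: the `paddlenlp_ops` module name and the `[batch, num_head, 1, length]` mask layout are assumptions inferred from the build setup and the kernel's indexing.

```python
# Hypothetical usage sketch for the set_alibi_mask_value custom op.
# Assumes the csrc extension builds into a `paddlenlp_ops` module; the
# kernel strides the mask by num_head * length per batch element, so a
# [batch, num_head, 1, length] decoder mask matches its indexing.
import paddle
from paddlenlp_ops import set_alibi_mask_value  # assumed module name

batch_size, num_head, max_length = 2, 32, 1024
tgt_generation_mask = paddle.zeros([batch_size, num_head, 1, max_length], dtype="float16")
stop_flags = paddle.to_tensor([False, True])
seq_lens = paddle.to_tensor([10, 25], dtype="int32")
alibi_slopes = paddle.rand([num_head], dtype="float32")
tgt_pos = paddle.to_tensor([10, 25], dtype="int64")

# Writes tgt_pos[b] * alibi_slopes[h] at position seq_lens[b] of each
# (batch, head) row in-place, and returns seq_lens with stopped rows zeroed.
sequence_lengths = set_alibi_mask_value(
    tgt_generation_mask, stop_flags, seq_lens, alibi_slopes, tgt_pos
)
```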

csrc/setup_cuda.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -31,6 +31,7 @@
             "./generation/write_cache_kv.cu",
             "./generation/encode_rotary_qk.cu",
             "./generation/top_p_sampling.cu",
+            "./generation/set_alibi_mask_value.cu",
         ]
     ),
 )
```

llm/predictor.py

Lines changed: 43 additions & 15 deletions

```diff
@@ -19,7 +19,6 @@
 import time
 from abc import abstractmethod
 from dataclasses import dataclass, field
-from distutils.command.config import config

 import numpy as np
 import paddle
@@ -139,6 +138,7 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer = N
         self.tokenizer = tokenizer
         self.return_tensors = "pd"
         self.tensor_parallel_rank, self.tensor_parallel_degree = init_dist_env()
+        self.model_config.tensor_parallel_rank, self.model_config.tensor_parallel_degree = init_dist_env()

     def _preprocess(self, source):
         tokenized_source = self.tokenizer(
@@ -284,11 +284,11 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
             self.cache_kvs[0].shape[-3],
             self.cache_kvs[0].shape[-1],
         )
-        total_max_length = config.src_length + config.max_length
-        self.pre_ids = paddle.full([config.batch_size, total_max_length], -1, dtype="int64")
+        self.total_max_length = config.src_length + config.max_length
+        self.pre_ids = paddle.full([config.batch_size, self.total_max_length], -1, dtype="int64")
         if "chatglm" in self.architectures:
             self.attention_mask = paddle.ones(
-                shape=(config.batch_size, 1, total_max_length, total_max_length),
+                shape=(config.batch_size, 1, self.total_max_length, self.total_max_length),
                 dtype=self.dtype,
             )
             self.tgt_pos = paddle.ones(
@@ -297,15 +297,17 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
             )
         else:
             self.attention_mask = paddle.zeros(
-                shape=(config.batch_size, 1, total_max_length, total_max_length),
+                shape=(config.batch_size, 1, self.total_max_length, self.total_max_length),
                 dtype=self.dtype,
             )

         self.tgt_generation_mask = paddle.zeros(
-            shape=[config.batch_size, 1, 1, total_max_length],
+            shape=[config.batch_size, 1, 1, self.total_max_length],
             dtype=self.dtype,
         )
-        self.arange_tensor_encoder = paddle.zeros(shape=(config.batch_size, 1, total_max_length), dtype=self.dtype)
+        self.arange_tensor_encoder = paddle.zeros(
+            shape=(config.batch_size, 1, self.total_max_length), dtype=self.dtype
+        )

         if config.export_precache:
             if config.prefix_path:
@@ -342,6 +344,10 @@ def _postprocess(self, predictions):
         return None

     def _preprocess(self, source):
+        self.attention_mask[:] = 0
+        self.tgt_generation_mask[:] = 0
+        pre_caches_length = 0 if not self.config.export_precache else self.pre_caches[0].shape[-2]
+
         if "chatglm" in self.architectures:
             inputs = dybatch_preprocess(
                 self.tokenizer,
@@ -370,12 +376,12 @@ def _preprocess(self, source):
             )
             for i in range(inputs["input_ids"].shape[0]):
                 length = inputs["seq_len_encoder"][i][0]
-                self.attention_mask[i, 0, :length, :length] = paddle.tril(
+                self.attention_mask[i, :, :length, :length] = paddle.tril(
                     paddle.ones(shape=(length, length), dtype=self.config.dtype)
                 )
-                self.arange_tensor_encoder[i, 0, :length] = paddle.arange(length).astype(self.config.dtype)
+                self.arange_tensor_encoder[i, :, :length] = paddle.arange(length).astype(self.config.dtype)

-                self.tgt_generation_mask[i, 0, 0, :length] = paddle.ones(shape=[1, length], dtype=self.config.dtype)
+                self.tgt_generation_mask[i, :, 0, :length] = paddle.ones(shape=[1, length], dtype=self.config.dtype)
             # alibi encoder
             alibi_slopes = get_alibi_slopes(self.model_config.n_head)
             inputs["position_ids"] = paddle.to_tensor(alibi_slopes, dtype="float32")
@@ -402,16 +408,16 @@ def _preprocess(self, source):
                 [
                     inputs["input_ids"].shape[0],
                     self.model_config.n_head // self.model_config.tensor_parallel_degree,
-                    self.config.max_length,
-                    self.config.max_length,
+                    self.total_max_length,
+                    self.total_max_length,
                 ]
             )
             alibi_decoder = alibi.expand(
                 [
                     inputs["input_ids"].shape[0],
                     self.model_config.n_head // self.model_config.tensor_parallel_degree,
                     1,
-                    self.config.max_length,
+                    self.total_max_length,
                 ]
             )
             self.attention_mask = (
@@ -422,7 +428,6 @@ def _preprocess(self, source):
             )

         else:
-            pre_caches_length = 0 if not self.config.export_precache else self.pre_caches[0].shape[-2]
             inputs = dybatch_preprocess(
                 self.tokenizer,
                 source,
@@ -655,7 +660,7 @@ def create_predictor(
             from paddlenlp.experimental.transformers import (
                 LlamaForCausalLMInferenceModel as LlamaInferenceModel,
             )
-
+
             config.tensor_parallel_degree = tensor_parallel_degree
             config.tensor_parallel_rank = tensor_parallel_rank
             config.quant_bits = -1
@@ -679,6 +684,20 @@ def create_predictor(
                 dtype=predictor_args.dtype,
             )
             model.eval()
+        elif "bloom" in config.architectures[0].lower():
+            from paddlenlp.experimental.transformers import (
+                BloomForCausalLMInferenceModel,
+            )
+
+            config.tensor_parallel_degree = tensor_parallel_degree
+            config.tensor_parallel_rank = tensor_parallel_rank
+            model = BloomForCausalLMInferenceModel.from_pretrained(
+                predictor_args.model_name_or_path,
+                config=config,
+                dtype=predictor_args.dtype,
+            )
+            cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+            model.eval()
         predictor = DygraphInferencePredictor(predictor_args, model=model, tokenizer=tokenizer)
     elif predictor_args.mode == "static":
         config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
@@ -698,6 +717,15 @@ def create_predictor(
                 config, predictor_args.batch_size
             )
             predictor = StaticInferencePredictor(predictor_args, cache_kvs_shape, tokenizer=tokenizer)
+        elif "bloom" in config.architectures[0].lower():
+            from paddlenlp.experimental.transformers import (
+                BloomForCausalLMInferenceModel,
+            )
+
+            cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, predictor_args.batch_size)
+            predictor = StaticInferencePredictor(
+                predictor_args, cache_kvs_shape=cache_kvs_shape, tokenizer=tokenizer
+            )
     else:
         raise ValueError("the `mode` should be one of [dynamic, static]")
     return predictor
```
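The ALiBi branch of `_preprocess` relies on `get_alibi_slopes(n_head)` to build the per-head bias that is expanded into the encoder and decoder attention masks. The helper's implementation is not part of this diff; the sketch below is the standard ALiBi slope computation from Press et al. (2022), offered as a reference for what such a helper typically returns. PaddleNLP's actual version may differ in details.

```python
import math

import numpy as np


def get_alibi_slopes(num_heads):
    """Per-head ALiBi slopes: a geometric sequence starting at 2^(-8/n).

    Sketch of the standard formula; for head counts that are not a power
    of two, extra slopes are interleaved from the next power of two.
    """
    closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
    base = 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3)))
    powers = np.arange(1, 1 + closest_power_of_2)
    slopes = np.power(base, powers)
    if closest_power_of_2 != num_heads:
        extra_base = 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3)))
        num_remaining = num_heads - closest_power_of_2
        extra_powers = np.arange(1, 1 + 2 * num_remaining, 2)
        slopes = np.concatenate([slopes, np.power(extra_base, extra_powers)])
    return slopes  # shape [num_heads]


# e.g. get_alibi_slopes(8) -> [0.5, 0.25, 0.125, ..., 0.00390625]
```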

paddlenlp/experimental/transformers/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from .bloom import *
 from .chatglm import *
 from .fused_transformer_layers import *
 from .llama import *
```
paddlenlp/experimental/transformers/bloom/__init__.py (new file)

Lines changed: 15 additions & 0 deletions

```python
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .modeling import *
```
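Together, these re-exports make the Bloom inference model importable from `paddlenlp.experimental.transformers`, which is what the new branches in `create_predictor` rely on. A rough sketch of the dynamic-mode path follows, mirroring the calls added in llm/predictor.py above; the checkpoint name and dtype are illustrative assumptions, not values from this commit.

```python
# Sketch of the dynamic-mode Bloom path added to create_predictor.
# "bigscience/bloom-560m" is an illustrative checkpoint name.
from paddlenlp.experimental.transformers import BloomForCausalLMInferenceModel
from paddlenlp.transformers import AutoConfig, AutoTokenizer

model_name = "bigscience/bloom-560m"
config = AutoConfig.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Single-card setup; the predictor normally fills these from init_dist_env().
config.tensor_parallel_degree = 1
config.tensor_parallel_rank = 0

model = BloomForCausalLMInferenceModel.from_pretrained(
    model_name, config=config, dtype="float16"
)
# Shapes for the pre-allocated KV caches at a fixed max batch size.
cache_kvs_shape = BloomForCausalLMInferenceModel.get_cache_kvs_shape(config, 1)
model.eval()
```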
