Commit 71b3be3

[LLM Inference] add --use_fake_parameter option for ptq fake scales and fix compute error of total_max_length (#8955)
* update some code
* update
* update
* update
* update tune_cublaslt_gemm demo
* fix step in tune_cublaslt_gemm
1 parent 0ec78aa commit 71b3be3

7 files changed: +249 −53 lines


csrc/generation/test_tune_cublaslt_gemm.py

Lines changed: 22 additions & 5 deletions
@@ -12,14 +12,31 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from paddlenlp_ops import tune_cublaslt_gemm
 import paddle
+from paddlenlp_ops import tune_cublaslt_gemm
+
+M_tensor = paddle.to_tensor([32768])
+
+# llama3.1-8b
+k1 = [4096, 4096, 4096, 14336]
+n1 = [6144, 4096, 28672, 4096]
+
+# llama3.1-405b mp=8
+k2 = [16384, 16384, 16384, 6656]
+n2 = [2560, 16384, 13312, 16384]
+
+# qwen2-1.5b
+k3 = [1536, 1536, 1536, 8960]
+n3 = [2048, 1536, 17920, 1536]
+
+# qwen2-7b
+k4 = [3584, 3584, 3584, 18944]
+n4 = [4608, 3584, 37888, 3584]
 
-M_tensor = paddle.to_tensor([1024])
-K_tensor = paddle.to_tensor([1024, 2048])
-N_tensor = paddle.to_tensor([4096, 8192])
+K_tensor = paddle.to_tensor(k1 + k2 + k3 + k4)
+N_tensor = paddle.to_tensor(n1 + n2 + n3 + n4)
 
 Dtype = "int8"
-Path = "./search.csv"
+Path = "./cublaslt_gemm_search.csv"
 
 tune_cublaslt_gemm(M_tensor, K_tensor, N_tensor, Dtype, True, False, Path)
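For context, each (k, n) pair above is one per-layer GEMM of the corresponding model: the fused QKV projection, the attention output projection, the fused gate/up projection, and the down projection. A minimal sketch of how such shapes can be derived; the helper name and the Llama-3.1-8B dimensions are assumptions, and it is only checked against the k1/n1 row, since the tensor-parallel rows above may fold head counts differently:

# Hypothetical helper (not part of the repository): derive the four per-layer
# GEMM shapes of a Llama-style decoder layer from its dimensions.
def gemm_shapes(hidden_size, num_heads, num_kv_heads, head_dim, intermediate_size):
    k = [hidden_size, num_heads * head_dim, hidden_size, intermediate_size]
    n = [
        (num_heads + 2 * num_kv_heads) * head_dim,  # fused QKV output
        hidden_size,                                # o_proj output
        2 * intermediate_size,                      # fused gate/up output
        hidden_size,                                # down_proj output
    ]
    return k, n

# Assumed Llama-3.1-8B dimensions; reproduces the k1/n1 row of the test above.
assert gemm_shapes(4096, 32, 8, 128, 14336) == ([4096, 4096, 4096, 14336], [6144, 4096, 28672, 4096])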

csrc/generation/tune_cublaslt_gemm.cu

Lines changed: 3 additions & 0 deletions
@@ -759,6 +759,9 @@ void TuneCublasltGemm(const paddle::Tensor& M,
     case 1024:
       step = 1024;
       break;
+    case 8192:
+      step = 4096;
+      break;
   }
 }
 
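The added case matters because the demo above now sweeps M up to 32768: once M reaches 8192 the tuner advances in 4096-wide steps instead of continuing at 1024. A rough Python-only illustration of the staged step idea; the starting step and any thresholds below 1024 are assumptions, since only the 1024 and 8192 cases are visible in this hunk:

# Illustrative sketch, not the real CUDA code. Only the 1024 -> 1024 and the
# new 8192 -> 4096 transitions come from the diff above.
def sweep_m(m_max, start_step=1):
    m, step, values = 1, start_step, []
    while m <= m_max:
        values.append(m)
        if m == 1024:
            step = 1024
        elif m == 8192:  # new case: coarser steps for very large M
            step = 4096
        m += step
    return values

print(sweep_m(32768)[-6:])  # spacing widens to 4096 once M passes 8192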

llm/predict/export_model.py

Lines changed: 10 additions & 0 deletions
@@ -29,6 +29,15 @@ class ExportArgument:
     output_path: str = field(default=None, metadata={"help": "The output path of model."})
 
 
+def add_inference_args_to_config(model_config, args):
+    """Add export arguments to config."""
+    model_config.infer_model_block_size = args.block_size
+    model_config.infer_model_max_seq_len = args.total_max_length
+    model_config.infer_model_cachekv_int8_type = args.cachekv_int8_type
+    model_config.infer_model_dtype = args.dtype
+    model_config.infer_model_paddle_commit = paddle.version.commit
+
+
 def main():
     parser = PdArgumentParser((PredictorArgument, ModelArgument, ExportArgument))
     predictor_args, model_args, export_args = parser.parse_args_into_dataclasses()
@@ -60,6 +69,7 @@ def main():
             "cachekv_int8_type": predictor_args.cachekv_int8_type,
         },
     )
+    add_inference_args_to_config(predictor.model.config, predictor_args)
     predictor.model.config.save_pretrained(export_args.output_path)
     if predictor.generation_config is not None:
         predictor.generation_config.save_pretrained(export_args.output_path)
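After export, these infer_model_* fields are written into the saved config.json, so downstream tooling can discover the block size, maximum sequence length, cache-KV quantization type, dtype, and the Paddle commit that produced the model. A minimal sketch of reading them back; the output directory is a placeholder and it assumes extra config keys round-trip as attributes:

from paddlenlp.transformers import AutoConfig

# "./inference_model" is a hypothetical export output_path.
config = AutoConfig.from_pretrained("./inference_model")
print(config.infer_model_block_size)     # e.g. 64, the PredictorArgument default
print(config.infer_model_max_seq_len)    # 8192 with the revised total_max_length
print(config.infer_model_cachekv_int8_type, config.infer_model_dtype)
print(config.infer_model_paddle_commit)  # Paddle commit recorded at export time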

llm/predict/predictor.py

Lines changed: 11 additions & 16 deletions
@@ -101,7 +101,7 @@ class PredictorArgument:
             "help": "If benchmark set as `True`, we will force model decode to max_length, which is helpful to compute throughput. "
         },
     )
-
+    use_fake_parameter: bool = field(default=False, metadata={"help": "use fake parameter, for ptq scales now."})
     block_attn: bool = field(default=False, metadata={"help": "whether use block attention"})
     block_size: int = field(default=64, metadata={"help": "the block size for cache_kvs."})
     cachekv_int8_type: str = field(
@@ -124,7 +124,7 @@ class PredictorArgument:
 
     @property
     def total_max_length(self):
-        return self.src_length + self.max_length
+        return 8192  # Maximum sequence length.
 
 
 @dataclass
@@ -948,8 +948,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         result_queue = mp.Queue()
         tensor_queue = mp.Queue()
 
-        output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
-        output_tensor = output_tensor.cpu()
+        output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
         tensor_queue.put(output_tensor)
 
         read_res_process = mp.Process(
@@ -1074,8 +1073,7 @@ def predict(self, input_texts: list[str], return_tokens=False):
         result_queue = mp.Queue()
         tensor_queue = mp.Queue()
 
-        output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64")
-        output_tensor = output_tensor.cpu()
+        output_tensor = paddle.full(shape=[MAX_BSZ + 2, 1], fill_value=2, dtype="int64").cpu()
         tensor_queue.put(output_tensor)
 
         read_res_process = mp.Process(
@@ -1108,10 +1106,11 @@ def predict(self, input_texts: list[str], return_tokens=False):
 
 def get_ptq_multicards_num(directory):
     count = 0
-    prefix = "act_scales_"
-    for filename in os.listdir(directory):
-        if filename.startswith(prefix):
-            count += 1
+    if os.path.exists(directory):
+        prefix = "act_scales_"
+        for filename in os.listdir(directory):
+            if filename.startswith(prefix):
+                count += 1
     return count
 
 
@@ -1204,6 +1203,7 @@ def create_predictor(
         config.model_name_or_path = predictor_args.model_name_or_path
         config.quant_type = predictor_args.quant_type
         config.cachekv_int8_type = predictor_args.cachekv_int8_type
+        config.use_fake_parameter = predictor_args.use_fake_parameter
         config.single_card_ptq = True
         if predictor_args.avx_model:
            config.avx_type = predictor_args.avx_type
@@ -1381,15 +1381,10 @@ def create_predictor(
 
     elif predictor_args.mode == "static":
         config = AutoConfig.from_pretrained(predictor_args.model_name_or_path)
-        config.quant_type = predictor_args.quant_type
-        config.cachekv_int8_type = predictor_args.cachekv_int8_type
 
         if config.quantization_config.quant_type is not None:
-            predictor_args.quant_type = config.quantization_config.quant_type
-            config.quant_type = config.quantization_config.quant_type
-            if "c8" in config.quant_type:
+            if "c8" in config.quantization_config.quant_type:
                 predictor_args.cachekv_int8_type = "static"
-                config.cachekv_int8_type = "static"
 
         if "llama" in config.architectures[0].lower():
             if predictor_args.block_attn:
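Because PredictorArgument is parsed with PdArgumentParser, the new field surfaces directly as a --use_fake_parameter flag. A sketch of how the flag and the revised total_max_length behave; the import path, the checkpoint name, and the HfArgumentParser-style bool handling are assumptions:

from paddlenlp.trainer import PdArgumentParser
from predictor import PredictorArgument  # hypothetical import path for llm/predict/predictor.py

parser = PdArgumentParser(PredictorArgument)
(predictor_args,) = parser.parse_args_into_dataclasses(
    ["--model_name_or_path", "./llama-a8w8-ckpt", "--use_fake_parameter", "true"]
)
print(predictor_args.use_fake_parameter)  # True: placeholder PTQ scales are fabricated instead of read from JSON
print(predictor_args.total_max_length)    # 8192, no longer src_length + max_length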

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 120 additions & 31 deletions
@@ -43,7 +43,12 @@
     GenerationBlockInferenceModel,
     GenerationInferenceModel,
 )
-from paddlenlp.experimental.transformers.utils import infererence_model_from_pretrained
+from paddlenlp.experimental.transformers.utils import (
+    EmptyActScale,
+    EmptyCacheScale,
+    EmptyWeightScale,
+    infererence_model_from_pretrained,
+)
 from paddlenlp.transformers import LlamaConfig, LlamaPretrainedModel
 from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.llama.modeling import LlamaLMHead
@@ -346,7 +351,7 @@ def __init__(self, config: LlamaConfig):
         self.num_layers = config.num_hidden_layers
         self.epsilon = config.rms_norm_eps
         self.max_position_embeddings = config.max_position_embeddings
-        self.quant_type = config.quant_type
+        self.quant_type = config.get("quant_type", "")
 
         self.rope_theta = config.rope_theta
         self.use_neox = True
@@ -364,6 +369,8 @@ def __init__(self, config: LlamaConfig):
         self.smooth = config.quantization_config.smooth
         self.shift_smooth_all_linears = config.quantization_config.shift_smooth_all_linears
 
+        self.use_fake_parameter = config.get("use_fake_parameter", False)
+
         if self.use_weight_only:
             assert (
                 self.quant_type == "weight_only_int8" or self.quant_type == "weight_only_int4"
@@ -894,6 +901,30 @@ def set_state_dict(self, state_dict):
 
             if "a8w8" in self.quant_type:
                 if self.shift_smooth_all_linears:
+                    if self.use_fake_parameter:
+                        if "llama.layers.{}.self_attn.o_proj.shift_bias".format(idx) not in state_dict:
+                            state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)] = paddle.zeros(
+                                shape=[
+                                    (self.num_attention_heads // self.config.tensor_parallel_degree)
+                                    * (self.hidden_size // self.num_attention_heads)
+                                ],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.self_attn.o_proj.smooth_weight".format(idx)] = paddle.ones(
+                                shape=[
+                                    (self.num_attention_heads // self.config.tensor_parallel_degree)
+                                    * (self.hidden_size // self.num_attention_heads)
+                                ],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.mlp.down_proj.shift_bias".format(idx)] = paddle.zeros(
+                                shape=[self.intermediate_size // self.config.tensor_parallel_degree],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            state_dict["llama.layers.{}.mlp.down_proj.smooth_weight".format(idx)] = paddle.ones(
+                                shape=[self.intermediate_size // self.config.tensor_parallel_degree],
+                                dtype=paddle.get_default_dtype(),
+                            )
                     self.transformer_block.linear_shifts[idx].set_value(
                         paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.shift_bias".format(idx)])
                     )
@@ -908,6 +939,33 @@ def set_state_dict(self, state_dict):
                     )
 
                 if self.shift:
+                    if self.use_fake_parameter:
+                        if "llama.layers.{}.input_layernorm.bias".format(idx) not in state_dict:
+                            state_dict["llama.layers.{}.input_layernorm.bias".format(idx)] = paddle.zeros(
+                                shape=[self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
+                            state_dict["llama.layers.{}.post_attention_layernorm.bias".format(idx)] = paddle.zeros(
+                                [self.hidden_size], dtype=paddle.get_default_dtype()
+                            )
+                            unfused_state_dict["self_attn.q_proj.bias"] = paddle.zeros(
+                                shape=[self.num_attention_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["self_attn.k_proj.bias"] = paddle.zeros(
+                                shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["self_attn.v_proj.bias"] = paddle.zeros(
+                                shape=[self.num_key_value_heads * (self.hidden_size // self.num_attention_heads)],
+                                dtype=paddle.get_default_dtype(),
+                            )
+                            unfused_state_dict["mlp.gate_proj.bias"] = paddle.zeros(
+                                shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
+                            )
+                            unfused_state_dict["mlp.up_proj.bias"] = paddle.zeros(
+                                shape=[self.intermediate_size], dtype=paddle.get_default_dtype()
+                            )
+
                     self.transformer_block.ln_biases[idx].set_value(
                         paddle.to_tensor(state_dict["llama.layers.{}.input_layernorm.bias".format(idx)])
                     )
@@ -948,6 +1006,14 @@ def set_state_dict(self, state_dict):
                     self.transformer_block.ffn1_biases[idx].set_value(paddle.to_tensor(concated_ffn1_bias))
 
                     if self.shift_smooth_all_linears:
+                        if self.use_fake_parameter:
+                            if "llama.layers.{}.self_attn.o_proj.bias".format(idx) not in state_dict:
+                                state_dict["llama.layers.{}.self_attn.o_proj.bias".format(idx)] = paddle.zeros(
+                                    [self.hidden_size], dtype=paddle.get_default_dtype()
+                                )
+                                state_dict["llama.layers.{}.mlp.down_proj.layer.bias".format(idx)] = paddle.zeros(
+                                    [self.hidden_size], dtype=paddle.get_default_dtype()
+                                )
                         self.transformer_block.linear_biases[idx].set_value(
                             paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.bias".format(idx)])
                         )
@@ -981,41 +1047,64 @@ def set_state_dict(self, state_dict):
             weight_scale_map_dict = scale_map_dict["weight_scale"]
             cache_scale_map_dict = scale_map_dict["cachekv_scale"]
 
-            act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
-            weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
-            if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
-                act_scale_json_path = os.path.join(
-                    self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
+            if not self.use_fake_parameter:
+                act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
+                weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
+                if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
+                    act_scale_json_path = os.path.join(
+                        self.quant_model_path, f"act_scales_{self.config.tensor_parallel_rank}.json"
+                    )
+                    weight_scale_json_path = os.path.join(
+                        self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
+                    )
+                act_scale_loader = ActScalesLoader(
+                    act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
                 )
-                weight_scale_json_path = os.path.join(
-                    self.quant_model_path, f"weight_scales_{self.config.tensor_parallel_rank}.json"
+                weight_scales_loader = WeightScalesLoader(
+                    weight_scale_json_path,
+                    weight_scale_map_dict,
+                    num_of_layers=self.config.num_hidden_layers,
+                    concat_qkv=True,
+                    concat_ffn1=True,
+                )
+            else:
+                act_scale_loader = EmptyActScale(act_scale_map_dict, num_of_layers=self.config.num_hidden_layers)
+                weight_scales_loader = EmptyWeightScale(
+                    weight_scale_map_dict,
+                    num_of_layers=self.config.num_hidden_layers,
+                    num_head=self.num_attention_heads,
+                    dim_head=self.hidden_size // self.num_attention_heads,
+                    ffn_hidden_size=self.intermediate_size,
+                    num_key_value_heads=self.num_key_value_heads,
+                    mp_size=self.config.tensor_parallel_degree,
                 )
-            act_scale_loader = ActScalesLoader(
-                act_scale_json_path, act_scale_map_dict, num_of_layers=self.config.num_hidden_layers
-            )
             self.transformer_block.act_scales = act_scale_loader.scale
 
-            weight_scales_loader = WeightScalesLoader(
-                weight_scale_json_path,
-                weight_scale_map_dict,
-                num_of_layers=self.config.num_hidden_layers,
-                concat_qkv=True,
-                concat_ffn1=True,
-            )
-
             if self.config.cachekv_int8_type == "static":
-                cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
-                if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
-                    cache_scale_json_path = os.path.join(
-                        self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
+                if not self.use_fake_parameter:
+                    cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
+                    if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
+                        cache_scale_json_path = os.path.join(
+                            self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
+                        )
+                    cache_scales_loader = CacheScaleLoader(
+                        cache_scale_json_path,
+                        cache_scale_map_dict,
+                        num_of_layers=self.config.num_hidden_layers,
+                        num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
+                        num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
                     )
-                cache_scales_loader = CacheScaleLoader(
-                    cache_scale_json_path,
-                    cache_scale_map_dict,
-                    num_of_layers=self.config.num_hidden_layers,
-                    num_heads=self.num_attention_heads // self.config.tensor_parallel_degree,
-                    num_key_value_heads=self.num_key_value_heads // self.config.tensor_parallel_degree,
-                )
+                else:
+                    cache_scales_loader = EmptyCacheScale(
+                        cache_scale_map_dict,
+                        num_of_layers=self.config.num_hidden_layers,
+                        num_heads=self.num_attention_heads,
+                        dim_heads=self.hidden_size // self.num_attention_heads,
+                        is_channel_wise=False,
+                        num_key_value_heads=self.num_key_value_heads,
+                        mp_size=self.config.tensor_parallel_degree,
+                    )
+
                 for k, v in cache_scales_loader.scale.items():
                     for i_layer, weight_scale in enumerate(v):
                         weight_scale = weight_scale.astype("float32")
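The Empty* loaders themselves are defined in paddlenlp/experimental/transformers/utils.py, which is not part of this excerpt. Conceptually they only need to expose the same .scale mapping that ActScalesLoader, WeightScalesLoader, and CacheScaleLoader build from the PTQ JSON files, but filled with placeholder values. A minimal sketch of the idea for activation scales; the class name, dtype, and per-layer-scalar layout are assumptions, not the actual implementation:

import numpy as np

# Hypothetical stand-in for EmptyActScale: one placeholder activation scale per
# layer for every key in the scale map, so that
# self.transformer_block.act_scales = act_scale_loader.scale works unchanged.
class FakeActScales:
    def __init__(self, key_map_dict, num_of_layers):
        self.scale = {
            scale_type: np.ones([num_of_layers], dtype="float32")
            for scale_type in key_map_dict
        }

The real EmptyWeightScale and EmptyCacheScale additionally size their placeholders from num_head, dim_head, ffn_hidden_size, num_key_value_heads, and mp_size, mirroring the shapes the JSON-backed loaders would produce, which is why those arguments appear in the calls above.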
