diff --git a/llm/predict/predictor.py b/llm/predict/predictor.py
index ddac513e4890..b7d6d17590f3 100644
--- a/llm/predict/predictor.py
+++ b/llm/predict/predictor.py
@@ -196,6 +196,7 @@ class PredictorArgument:
     )
     dynamic_insert: bool = field(default=False, metadata={"help": "whether use dynamic insert"})
     total_request_num: int = field(default=None, metadata={"help": "The total number of request data"})
+    kv_cache_reuse: int = field(default=1)
 
     def __post_init__(self):
         if self.speculate_method is not None:
@@ -1155,6 +1156,11 @@ def init_cache_kvs(self):
             for cache_k_shape, cache_v_shape in zip(self.cache_k_shapes, self.cache_v_shapes):
                 self.cache_kvs.append(paddle.zeros(cache_k_shape, dtype=cachekv_dtype))
                 self.cache_kvs.append(paddle.zeros(cache_v_shape, dtype=cachekv_dtype))
+                if self.config.kv_cache_reuse:
+                    logger.warning(
+                        f"self.config.kv_cache_reuse = {self.config.kv_cache_reuse}, break, len(self.cache_kvs) = {len(self.cache_kvs)}"
+                    )
+                    break
         else:
             # for mla's absorption
             assert self.cache_v_shapes is None
diff --git a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
index dca9623f0cb4..1db636a6ff29 100644
--- a/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
+++ b/paddlenlp/experimental/transformers/deepseek_v2/modeling.py
@@ -47,9 +47,6 @@
     yarn_get_mscale,
     yarn_linear_ramp_mask,
 )
-from paddlenlp.transformers.model_outputs import (
-    BaseModelOutputWithPastAndCrossAttentions,
-)
 from paddlenlp.transformers.model_utils import (
     dy2st_nocheck_guard_context,
     register_base_model,
@@ -266,7 +263,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         self.weight_block_size = config.weight_block_size
         self.moe_quant_type = config.moe_quant_type
         self.rope_theta = config.rope_theta
-        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
 
         self.use_weight_only = False
         self.weightonly_group_size = -1
@@ -591,7 +587,6 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
         speculate_config = SpeculateConfig(
             speculate_method=config.get("speculate_method", None),
             speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
-            return_full_hidden_states=config.get("return_full_hidden_states", False),
         )
 
         transformer_config = FusedMultiTransformerConfig(
@@ -622,9 +617,9 @@ def __init__(self, config: DeepseekV2Config, base_model_prefix: str):
             rotary_emb=self.rotary_emb,
             norm_type="rmsnorm",
             rank_id=config.tensor_parallel_rank,
+            append_attn=config.append_attn,
             moe_config=moe_config,
             mla_config=mla_config,
-            append_attn=config.append_attn,
             speculate_config=speculate_config,
         )
 
@@ -1289,7 +1284,7 @@ def forward(
             inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1301,13 +1296,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-            cum_offsets=cum_offsets,
-        )
+        return hidden_states, full_hidden_states
 
 
 @register_base_model
@@ -1967,7 +1956,7 @@ def forward(
             inputs_embeds = self.eh_proj(inputs_embeds)
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1980,12 +1969,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        return hidden_states, full_hidden_states
 
 
 class DeepseekV2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, DeepseekV2PretrainedModel):
@@ -2212,7 +2196,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.deepseek_v2(
+        hidden_states, full_hidden_states = self.deepseek_v2(
             input_ids,
             inputs_embeds=inputs_embeds,
             src_mask=src_mask,
@@ -2230,21 +2214,7 @@ def forward(
             draft_tokens=draft_tokens,
             output_padding_offset=output_padding_offset,
         )
-        if self.return_full_hidden_states:
-            from paddlenlp_ops import rebuild_padding_v2
-
-            full_hidden_states = outputs[0]
-            cum_offsets = outputs[1]
-            hidden_states = rebuild_padding_v2(
-                full_hidden_states,
-                cum_offsets,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                self.max_seq_len,
-            )
-        else:
-            hidden_states = outputs[0]
+
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
@@ -2254,8 +2224,6 @@
         else:
             return logits
-            return logits
-
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
         if "lm_head.weight" in state_dict:
@@ -2363,7 +2331,7 @@ def forward(
         output_padding_offset=None,
         pre_hidden_states=None,
     ):
-        outputs = self.mtp(
+        hidden_states, _ = self.mtp(
             input_ids,
             src_mask=src_mask,
             caches=caches,
@@ -2382,8 +2350,6 @@ def forward(
             pre_hidden_states=pre_hidden_states,
         )
 
-        hidden_states = outputs[0]
-
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
diff --git a/paddlenlp/experimental/transformers/fused_transformer_layers.py b/paddlenlp/experimental/transformers/fused_transformer_layers.py
index 3276b9ac9663..22b7c7c78898 100644
--- a/paddlenlp/experimental/transformers/fused_transformer_layers.py
+++ b/paddlenlp/experimental/transformers/fused_transformer_layers.py
@@ -1481,7 +1481,8 @@ def forward(
             self.pre_process(**kwargs)
         kwargs["cum_offsets"] = cum_offsets
 
-        if caches is not None:
+        kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
+        if caches is not None and kv_cache_reuse is None:
             assert len(caches) == len(self.linear_weights) or len(caches) == 2 * len(self.linear_weights)
             assert self.num_layers == len(self.linear_weights)
 
@@ -1589,7 +1590,7 @@ def forward(
         kwargs["input_ids"] = input_ids
 
         out = self.post_process(**kwargs)
-        return out, caches
+        return out, kwargs["multi_block_output"]
 
 
 class FusedMultiTransformerPostLayernorm(FusedMultiTransformerBase):
@@ -3172,10 +3173,17 @@ def compute_attn(
         k_dequant_scales = kwargs.get("k_dequant_scales", None)
         v_dequant_scales = kwargs.get("v_dequant_scales", None)
 
+        kv_cache_reuse = kwargs.get("kv_cache_reuse", None)
+        if kv_cache_reuse:
+            k_cache_index = 0
+            v_cache_index = 1
+        else:
+            k_cache_index = 2 * i
+            v_cache_index = 2 * i + 1
         fmha_out = paddle.incubate.nn.functional.block_multihead_attention(
             qkv_out,
-            caches[2 * i],
-            caches[2 * i + 1],
+            caches[k_cache_index],
+            caches[v_cache_index],
             kwargs.get("seq_lens_encoder", None),
             kwargs.get("seq_lens_decoder", None),
             kwargs.get("seq_lens_this_time", None),
diff --git a/paddlenlp/experimental/transformers/generation_utils.py b/paddlenlp/experimental/transformers/generation_utils.py
index 0a09ceaa1f9c..f93102d0fde5 100644
--- a/paddlenlp/experimental/transformers/generation_utils.py
+++ b/paddlenlp/experimental/transformers/generation_utils.py
@@ -295,7 +295,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, cache_kvs, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -314,7 +314,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                     step_idx,
                     model_kwargs["stop_flags"],
                 )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
 
@@ -373,7 +372,7 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
 
         outputs = _forward_(**model_kwargs)  # first decoder
         next_tokens, model_kwargs = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_p,
             temperature,
             step_idx_ori,
@@ -389,8 +388,9 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
             paddle.sum(paddle.cast(model_kwargs["stop_flags"], "int64")),
             model_kwargs["stop_nums"],
         ):
+            outputs = _forward_(**model_kwargs)
             next_tokens, model_kwargs = _post_process_(
-                _forward_(**model_kwargs),
+                outputs[0] if isinstance(outputs, tuple) else outputs,
                 top_p,
                 temperature,
                 step_idx_ori,
@@ -692,7 +692,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -702,7 +702,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import set_preids_token_penalty_multi_scores
 
@@ -777,7 +777,7 @@ def _post_process_(
         outputs = _forward_(**model_kwargs)  # [bs, 1, dim_embed]
         # first decoder
         next_tokens = _post_process_(
-            outputs,
+            outputs[0] if isinstance(outputs, tuple) else outputs,
             top_k,
             top_p,
             penalty_score,
@@ -806,7 +806,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -816,7 +816,7 @@ def _post_process_(
             model_kwargs,
         ):
             step_idx = model_kwargs["step_idx"]
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             from paddlenlp_ops import speculate_get_token_penalty_multi_scores
 
@@ -959,7 +959,7 @@ def _forward_(**args):
             return self(**model_inputs)
 
         def _post_process_(
-            outputs,
+            logits,
             top_k,
             top_p,
             penalty_score,
@@ -968,7 +968,7 @@ def _post_process_(
             temperature,
             model_kwargs,
         ):
-            logits = paddle.cast(outputs, paddle.float32)
+            logits = paddle.cast(logits, paddle.float32)
 
             probs = F.softmax(logits)
 
@@ -1191,7 +1191,7 @@ def _forward_(**args):
             model_inputs = self.prepare_inputs_for_generation(input_ids, **args)
             return self(**model_inputs)
 
-        def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
+        def _post_process_(logits, top_p, temperature, step_idx_ori, model_kwargs):
             cache = model_kwargs.get("cache", None)
             just_decoder = model_kwargs["seq_len_encoder"] == 0
             if cache is None:  # first decoder
@@ -1211,7 +1211,6 @@ def _post_process_(outputs, top_p, temperature, step_idx_ori, model_kwargs):
                     step_idx,
                     model_kwargs["stop_flags"],
                 )
-            logits = outputs[0] if isinstance(outputs, tuple) else outputs
             logits = paddle.cast(logits, paddle.float32)
             logits = logits_processors(model_kwargs["all_input_ids"], logits, decoding_step=step_idx_ori)
diff --git a/paddlenlp/experimental/transformers/llama/modeling.py b/paddlenlp/experimental/transformers/llama/modeling.py
index 8ada8a726b86..3a32e3c7ebda 100644
--- a/paddlenlp/experimental/transformers/llama/modeling.py
+++ b/paddlenlp/experimental/transformers/llama/modeling.py
@@ -257,19 +257,12 @@ def forward(
         # merge batch and seq_len dimension.
         inputs_embeds = inputs_embeds.reshape([batch * seq_len, hidden_dim])
 
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)
 
         hidden_states = inputs_embeds
 
         # decoder layers
-        all_hidden_states = () if output_hidden_states else None
         with dy2st_nocheck_guard_context():
             hidden_states = self.transformer_block(
                 input_ids=input_ids,
@@ -280,11 +273,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        if output_hidden_states:
-            all_hidden_states = all_hidden_states + (hidden_states,)
-
-        if not return_dict:
-            return tuple(v for v in [hidden_states, None, all_hidden_states, None] if v is not None)
+        return hidden_states
 
     @paddle.no_grad()
     # avx
@@ -401,7 +390,6 @@ def __init__(self, config: LlamaConfig):
         self.epsilon = config.rms_norm_eps
         self.max_position_embeddings = config.max_position_embeddings
         self.quant_type = config.get("quant_type", "")
-        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
         self.rope_theta = config.rope_theta
 
         self.use_neox = True
@@ -617,7 +605,6 @@ def __init__(self, config: LlamaConfig):
         speculate_config = SpeculateConfig(
             speculate_method=config.get("speculate_method", None),
             speculate_max_draft_token_num=config.get("speculate_max_draft_token_num", 5),
-            return_full_hidden_states=config.get("return_full_hidden_states", False),
         )
 
         hpu_config = HpuConfig(
@@ -1496,7 +1483,7 @@ def forward(
             inputs_embeds = self.embed_tokens(ids_remove_padding)
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1508,17 +1495,11 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-            cum_offsets=cum_offsets,
-        )
+        return hidden_states, full_hidden_states
 
 
 @register_base_model
-class EagleForLlamaInferenceModel(LlamaBlockInferenceModel):
+class EagleForLlamaBlockInferenceModel(LlamaBlockInferenceModel):
     def __init__(self, config: LlamaConfig):
         self.append_attn = config.append_attn
         super().__init__(config)
@@ -2078,8 +2059,9 @@ def forward(
         excess_blocks=None,
         draft_tokens=None,
         output_padding_offset=None,
+        **kwargs,
     ):
-        outputs = self.llama(
+        hidden_states, full_hidden_states = self.llama(
            input_ids,
            src_mask=src_mask,
            caches=caches,
@@ -2096,24 +2078,9 @@ def forward(
            excess_blocks=excess_blocks,
            draft_tokens=draft_tokens,
            output_padding_offset=output_padding_offset,
+            **kwargs,
        )
-        # hidden_states = outputs[0]
-        if self.return_full_hidden_states:
-            from paddlenlp_ops import rebuild_padding_v2
-
-            # full_hidden_states = outputs[1]
-            full_hidden_states = outputs[0]
-            cum_offsets = outputs[1]
-            hidden_states = rebuild_padding_v2(
-                full_hidden_states,
-                cum_offsets,
-                seq_lens_decoder,
-                seq_lens_encoder,
-                output_padding_offset,
-                self.max_seq_len,
-            )
-        else:
-            hidden_states = outputs[0]
+
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
@@ -2139,7 +2106,7 @@ def __init__(self, config):
         self.verify_window = config.get("speculate_verify_window", 2)
         self.max_seq_len = config.max_seq_len
 
-        self.eagle = EagleForLlamaInferenceModel(config)
+        self.eagle = EagleForLlamaBlockInferenceModel(config)
         if config.tie_word_embeddings:
             self.lm_head = LlamaLMHead(config, embedding_weights=self.llama.embed_tokens.weight, transpose_y=True)
             self.tie_weights()
diff --git a/paddlenlp/experimental/transformers/qwen2/modeling.py b/paddlenlp/experimental/transformers/qwen2/modeling.py
index 141bd4df1bc1..ab1a5ae04c74 100644
--- a/paddlenlp/experimental/transformers/qwen2/modeling.py
+++ b/paddlenlp/experimental/transformers/qwen2/modeling.py
@@ -54,7 +54,6 @@
 from paddlenlp.transformers.conversion_utils import split_param_func
 from paddlenlp.transformers.model_outputs import (  # CausalLMOutputWithCrossAttentions,
     BaseModelOutputWithPast,
-    BaseModelOutputWithPastAndCrossAttentions,
     CausalLMOutputWithPast,
 )
 from paddlenlp.transformers.model_utils import (
@@ -1325,7 +1324,7 @@ def forward(
             inputs_embeds = inputs_embeds.reshape([-1, inputs_embeds.shape[2]])
 
         with dy2st_nocheck_guard_context():
-            hidden_states, _ = self.transformer_block(
+            hidden_states, full_hidden_states = self.transformer_block(
                 input_ids=input_ids,
                 src=inputs_embeds,
                 cum_offsets=cum_offsets,
@@ -1337,12 +1336,7 @@ def forward(
             )
         hidden_states = self.norm(hidden_states)
 
-        return BaseModelOutputWithPastAndCrossAttentions(
-            last_hidden_state=hidden_states,
-            past_key_values=None,
-            hidden_states=None,
-            attentions=None,
-        )
+        return hidden_states, full_hidden_states
 
 
 class Qwen2ForCausalLMBlockInferenceModel(GenerationBlockInferenceModel, Qwen2PretrainedModel):
@@ -1359,6 +1353,7 @@ def __init__(self, config: Qwen2Config, base_model_prefix: str = "qwen2"):
         self.max_candidate_len = config.get("speculate_max_candidate_len", 5)
         self.verify_window = config.get("speculate_verify_window", 2)
         self.max_seq_len = config.max_seq_len
+        self.return_full_hidden_states = config.get("return_full_hidden_states", False)
 
         self.qwen2 = Qwen2BlockInferenceModel(config, base_model_prefix)
         if config.tie_word_embeddings:
@@ -1555,7 +1550,7 @@ def forward(
         draft_tokens=None,
         output_padding_offset=None,
     ):
-        outputs = self.qwen2(
+        hidden_states, full_hidden_states = self.qwen2(
             input_ids,
             inputs_embeds=inputs_embeds,
             src_mask=src_mask,
@@ -1575,13 +1570,15 @@ def forward(
             output_padding_offset=output_padding_offset,
         )
 
-        hidden_states = outputs[0]
         logits = self.lm_head(
             hidden_states,
             tensor_parallel_output=False,
         )
 
-        return logits
+        if self.return_full_hidden_states:
+            return logits, full_hidden_states
+        else:
+            return logits
 
     @paddle.no_grad()
     def set_state_dict(self, state_dict):
diff --git a/paddlenlp/transformers/model_outputs.py b/paddlenlp/transformers/model_outputs.py
index d14f40e9a215..9836aff252fe 100644
--- a/paddlenlp/transformers/model_outputs.py
+++ b/paddlenlp/transformers/model_outputs.py
@@ -662,10 +662,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput):
             Attentions weights of the decoder's cross-attention layer, after the attention softmax,
             used to compute the weighted average in the cross-attention heads.
 
-        cum_offsets (`tuple(paddle.Tensor)`, *optional*, needed when `return_full_hidden_states=True`:
-            Tuple of `paddle.Tensor` (one for each layer) of shape `(batch_size, 1)`.
-
-            Offset of the current batch.
""" last_hidden_state: paddle.Tensor = None @@ -673,7 +669,6 @@ class BaseModelOutputWithPastAndCrossAttentions(ModelOutput): hidden_states: Optional[Tuple[paddle.Tensor]] = None attentions: Optional[Tuple[paddle.Tensor]] = None cross_attentions: Optional[Tuple[paddle.Tensor]] = None - cum_offsets: Optional[Tuple[paddle.Tensor]] = None @dataclass diff --git a/slm/pipelines/examples/contrastive_training/README.md b/slm/pipelines/examples/contrastive_training/README.md index d3148786ec06..abc0ec68b22d 100644 --- a/slm/pipelines/examples/contrastive_training/README.md +++ b/slm/pipelines/examples/contrastive_training/README.md @@ -178,6 +178,9 @@ python -u evaluation/eval_mteb.py \ - `padding_side`:设置 padding 的位置,可取 left 或 right - `add_bos_token`:是否添加起始符,0表示不添加,1表示添加 - `add_eos_token`:是否添加结束符,0表示不添加,1表示添加 +- `quant_type`:是否使用量化加载,可选项包括 weight_only_int8,weight_only_int4,no,默认为 no,即不进行量化 +- `kv_cache_reuse`: 量化加载时,是否仅预分配首层 kv_cache 并重复利用,0 表示不复用,1 表示复用,默认为 0,此策略可降低量化加载时显存占用 + # MTEB 评估 [MTEB](https://github.com/embeddings-benchmark/mteb) diff --git a/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py b/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py index eb676d23f8f2..64a1b03c645d 100644 --- a/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py +++ b/slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. - import argparse import logging import mteb from datasets import load_dataset +from modelling_quant import HiddenPredictorWrapper from mteb import MTEB from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval, HFDataLoader from mteb.abstasks.TaskMetadata import TaskMetadata @@ -108,6 +108,8 @@ def get_args(): parser.add_argument("--padding_side", default="left", type=str) # right, left parser.add_argument("--add_bos_token", default=0, type=int) parser.add_argument("--add_eos_token", default=1, type=int) + parser.add_argument("--quant_type", default="no", type=str) + parser.add_argument("--kv_cache_reuse", default=0, type=int) return parser.parse_args() @@ -188,30 +190,46 @@ def get_args(): tokenizer.add_bos_token = bool(args.add_bos_token) tokenizer.add_eos_token = bool(args.add_eos_token) - encode_model = BiEncoderModel( - model_name_or_path=args.base_model_name_or_path, - normalized=True, - sentence_pooling_method=args.pooling_method, - query_instruction=args.query_instruction, - tokenizer=tokenizer, - eval_batch_size=args.eval_batch_size, - max_seq_length=args.max_seq_length, - model_flag=args.model_flag, - dtype=args.dtype, - ) + if args.quant_type != "no": + encode_model = HiddenPredictorWrapper( + model_name_or_path=args.base_model_name_or_path, + normalized=True, + sentence_pooling_method=args.pooling_method, + query_instruction=args.query_instruction, + tokenizer=tokenizer, + eval_batch_size=args.eval_batch_size, + max_seq_length=args.max_seq_length, + model_flag=args.model_flag, + dtype=args.dtype, + quant_type=args.quant_type, + kv_cache_reuse=args.kv_cache_reuse, + ) - if args.peft_model_name_or_path: - lora_config = LoRAConfig.from_pretrained(args.peft_model_name_or_path) - lora_config.merge_weights = True - encode_model.config = ( - encode_model.model_config - ) # for NV-Embed, this is no needed, but for repllama, this is needed - encode_model.config.tensor_parallel_degree = 1 - encode_model = LoRAModel.from_pretrained( - encode_model, args.peft_model_name_or_path, lora_config=lora_config, 
-            encode_model, args.peft_model_name_or_path, lora_config=lora_config, dtype=lora_config.dtype
+    else:
+        encode_model = BiEncoderModel(
+            model_name_or_path=args.base_model_name_or_path,
+            normalized=True,
+            sentence_pooling_method=args.pooling_method,
+            query_instruction=args.query_instruction,
+            tokenizer=tokenizer,
+            eval_batch_size=args.eval_batch_size,
+            max_seq_length=args.max_seq_length,
+            model_flag=args.model_flag,
+            dtype=args.dtype,
         )
 
-    encode_model.eval()
+        if args.peft_model_name_or_path:
+            lora_config = LoRAConfig.from_pretrained(args.peft_model_name_or_path)
+            lora_config.merge_weights = True
+            encode_model.config = (
+                encode_model.model_config
+            )  # for NV-Embed, this is not needed, but for repllama, this is needed
+            encode_model.config.tensor_parallel_degree = 1
+            encode_model = LoRAModel.from_pretrained(
+                encode_model, args.peft_model_name_or_path, lora_config=lora_config, dtype=lora_config.dtype
+            )
+
+        encode_model.eval()
 
     logger.info("Ready to eval")
 
     if args.task_name == "MSMARCOTITLE":
@@ -226,7 +244,7 @@ def get_args():
     evaluation = MTEB(tasks=mteb.get_tasks(tasks=[args.task_name]))
     evaluation.run(
         encode_model,
-        output_folder=f"{args.output_folder}/{args.task_name}/{args.pooling_method}",
+        output_folder=f"{args.output_folder}/{args.task_name}/{args.quant_type}/{args.pooling_method}",
         score_function="dot",
         eval_splits=[args.task_split],
     )
diff --git a/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py b/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py
new file mode 100644
index 000000000000..39124a7c2f1e
--- /dev/null
+++ b/slm/pipelines/examples/contrastive_training/evaluation/modelling_quant.py
@@ -0,0 +1,396 @@
+# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import dataclasses
+import os
+import sys
+from typing import Dict, List, Optional, Union
+
+import numpy as np
+import paddle
+import paddle.incubate.multiprocessing as mp
+from paddle.distributed import fleet
+from tqdm import tqdm
+
+from paddlenlp.transformers import AutoConfig
+from paddlenlp.trl import llm_utils
+from paddlenlp.utils.log import logger
+
+current_script_dir = os.path.dirname(os.path.abspath(__file__))
+project_root = os.path.dirname(current_script_dir)
+if project_root not in sys.path:
+    sys.path.insert(0, project_root)
+
+from llm.predict.predictor import (
+    DygraphBlockInferencePredictor,
+    ModelArgument,
+    PredictorArgument,
+    PretrainedModel,
+    PretrainedTokenizer,
+)
+from paddlenlp.transformers import (
+    AutoInferenceModelForCausalLM,
+    Llama3Tokenizer,
+    LlamaTokenizer,
+)
+from paddlenlp.utils.env import MAX_BSZ, MAX_DRAFT_TOKENS, SPECULATE_MAX_BSZ
+
+MODEL_FLAG = ""
+MAX_SEQ_LENGTH = 0
+QUERY_DOC_FLAG_FOR_LLARA = ""
+
+
+class DygraphBlockInferenceHiddenPredictor(DygraphBlockInferencePredictor):
+    def __init__(
+        self, config: PredictorArgument, tokenizer: PretrainedTokenizer = None, model: PretrainedModel = None, **kwargs
+    ):
+        super().__init__(config, tokenizer, model, **kwargs)
+
+    @paddle.no_grad()
+    def encode(self, sentences: list[str]):
+
+        if MODEL_FLAG == "llara":
+            logger.warning('MODEL_FLAG == "llara"')
+            sentences = self.preprocess_sentences_for_llara(sentences, QUERY_DOC_FLAG_FOR_LLARA)
+
+        total = 0
+        all_embeddings = []
+        for start_index in tqdm(range(0, len(sentences), self.config.batch_size), desc="Batches"):
+            sentences_batch = sentences[start_index : start_index + self.config.batch_size]
+
+            self._preprocess(sentences_batch)
+            if self.proposer is not None:
+                self.proposer.insert_query(
+                    base_model_inputs=self.model_inputs, real_bs=len(sentences_batch), seq_lens=self.seq_lens
+                )
+            result_queue = mp.Queue()
+            tensor_queue = mp.Queue()
+            done_event = mp.Event()
+
+            output_tensor_shape = (
+                [SPECULATE_MAX_BSZ * MAX_DRAFT_TOKENS + SPECULATE_MAX_BSZ + 2, 1]
+                if self.proposer
+                else [MAX_BSZ + 2, 1]
+            )
+            read_res_func = llm_utils.speculate_read_res if self.proposer else llm_utils.read_res
+
+            read_res_process = mp.Process(
+                target=read_res_func, args=[self.model_name_or_path, tensor_queue, result_queue, done_event]
+            )
+            if self.tensor_parallel_rank == 0:
+                read_res_process.start()
+
+            output_tensor = paddle.full(shape=output_tensor_shape, fill_value=2, dtype="int64").cpu()
+            tensor_queue.put(output_tensor)
+
+            if self.tensor_parallel_rank == 0:
+                done_event.wait()
+
+            if self.proposer is not None:
+                self.proposer.run(
+                    self.model_inputs,
+                    # real_batch_size=self.batch_size,
+                    real_batch_size=len(sentences_batch),
+                    seq_lens_this_time=self.model_inputs["seq_lens_this_time"],
+                    base_model_full_hidden_states=self.full_hidden_states,
+                )
+
+            inputs = self.model_inputs
+
+            _, full_hidden_states = self.model(
+                input_ids=inputs["input_ids"],
+                seq_lens_this_time=inputs["seq_lens_this_time"],
+                caches=inputs["cache_kvs"],
+                seq_lens_encoder=inputs["seq_lens_encoder"],
+                seq_lens_decoder=inputs["seq_lens_decoder"],
+                block_tables=inputs["block_tables"],
+                rope_emb=inputs["rope_emb"],
+                kv_cache_reuse=self.config.kv_cache_reuse,
+            )
+
+            last_hidden_state_tensor = self.split_hidden_states_by_seq_lens(
+                full_hidden_states, inputs["seq_lens_this_time"]
+            )
+            total += last_hidden_state_tensor.shape[0]
+
+            assert last_hidden_state_tensor.shape[0] == len(
+                sentences_batch
+            ), f"Output batch size mismatch: {last_hidden_state_tensor.shape[0]} vs {len(sentences_batch)}"
+            assert (
+                last_hidden_state_tensor.shape[1] == self.model.config.hidden_size
+            ), f"Hidden size mismatch: {last_hidden_state_tensor.shape[1]} vs {self.model.config.hidden_size}"
+
+            if self.config.normalized:
+                embeddings = paddle.nn.functional.normalize(last_hidden_state_tensor, p=2, axis=-1)
+
+            all_embeddings.append(embeddings.cpu().numpy().astype("float32"))
+
+        return np.concatenate(all_embeddings, axis=0)
+
+    def split_hidden_states_by_seq_lens(self, hidden_states, seq_lens_this_time):
+        """
+        Args:
+            hidden_states (Tensor): shape [total_seq_len, hidden_size], e.g. [135, 2048]
+            seq_lens_this_time (Tensor): shape [batch_size, 1], e.g. [[127], [8]]
+
+        Returns:
+            Tensor: shape [batch_size, hidden_size]
+        """
+        if hasattr(seq_lens_this_time, "numpy"):  # Paddle tensor
+            seq_lens = seq_lens_this_time.numpy().flatten().tolist()
+        else:
+            seq_lens = [x[0] if isinstance(x, list) else x for x in seq_lens_this_time]
+
+        if self.config.sentence_pooling_method == "last":
+            if self.config.tokenizer.padding_side == "right":
+                split_hidden_states = []
+                start = 0
+                for length in seq_lens:
+                    end = start + length - 1
+                    split_hidden_states.append(hidden_states[end])
+                    start = start + length
+
+        elif self.config.sentence_pooling_method == "last_8":
+            split_hidden_states = []
+            start = 0
+            for length in seq_lens:
+                end = start + length - 1
+                split_hidden_states.append(paddle.mean(hidden_states[end - 7 : end + 1], axis=0))
+                start = start + length
+
+        else:
+            raise ValueError(f"the sentence_pooling_method {self.config.sentence_pooling_method} is not supported")
+        return paddle.stack(split_hidden_states, axis=0)  # shape: [batch_size, hidden_size]
+
+    def preprocess_sentences_for_llara(self, sentences: List[str], query_or_doc: str, **kwargs) -> List[str]:
+
+        prefix = '"'
+        if query_or_doc == "query":
+            suffix = '", predict the following passage within eight words: '
+        elif query_or_doc == "doc":
+            suffix = '", summarize the above passage within eight words: '
+        else:
+            raise ValueError(f"Invalid query_or_doc: {query_or_doc}")
+
+        logger.warning(f"query_or_doc: {query_or_doc}")
+
+        sentences_after_process = []
+        import tqdm
+
+        for sentence in tqdm.tqdm(sentences, desc="preprocess_sentences_for_llara"):
+            inputs = self.tokenizer(
+                sentence,
+                return_tensors=None,
+                max_length=MAX_SEQ_LENGTH - 20,
+                truncation=True,
+                add_special_tokens=False,
+            )
+            sentences_after_process.append(self.tokenizer.decode(inputs["input_ids"], skip_special_tokens=True))
+
+        sentences_after_process = [prefix + " " + sentence + " " + suffix for sentence in sentences_after_process]
+
+        return sentences_after_process
+
+
+class HiddenPredictorWrapper:
+    def __init__(
+        self,
+        model_name_or_path: str,
+        normalized: bool = True,
+        sentence_pooling_method: str = "last",
+        query_instruction: Optional[str] = None,
+        document_instruction: Optional[str] = None,
+        tokenizer=None,
+        eval_batch_size: int = 32,
+        max_seq_length: int = 512,
+        model_flag: str = None,
+        dtype: str = "float32",
+        quant_type: str = None,
+        kv_cache_reuse: bool = False,
+    ):
+        self.predictor_args = PredictorArgument()
+        self.model_args = ModelArgument()
+
+        override_fields = {
+            "model_name_or_path": model_name_or_path,
+            "sentence_pooling_method": sentence_pooling_method,
+            "dtype": dtype,
+            "quant_type": quant_type,
+            "return_full_hidden_states": 1,
+            "inference_model": True,
+            "block_attn": True,
+            "batch_size": eval_batch_size,
+            "kv_cache_reuse": bool(kv_cache_reuse),
+        }
+        self.model_name_or_path = model_name_or_path
+        self.dtype = dtype
+        self.normalized = normalized
+        self.sentence_pooling_method = sentence_pooling_method
+        self.query_instruction = query_instruction
+        self.document_instruction = document_instruction
+        self.eval_batch_size = eval_batch_size
+        self.max_seq_length = max_seq_length
+        self.model_flag = model_flag
+        self.quant_type = quant_type
+        self.tokenizer = tokenizer
+
+        for field in dataclasses.fields(self.predictor_args):
+            if field.name in override_fields and override_fields[field.name] is not None:
+                setattr(self.predictor_args, field.name, override_fields[field.name])
+
+        for field in dataclasses.fields(self.model_args):
+            if field.name in override_fields and override_fields[field.name] is not None:
+                setattr(self.model_args, field.name, override_fields[field.name])
+
+        self.predictor_args.tokenizer = self.tokenizer
+        self.predictor_args.sentence_pooling_method = self.sentence_pooling_method
+        self.predictor_args.normalized = self.normalized
+        self.predictor = self._create_predictor()
+
+    def _create_predictor(self):
+
+        model_config = AutoConfig.from_pretrained(self.predictor_args.model_name_or_path)
+
+        llm_utils.set_triton_cache(self.predictor_args.model_name_or_path, self.predictor_args.mode)
+        try:
+            from paddle.utils import try_import
+
+            try_import("paddlenlp_ops")
+        except ImportError:
+            logger.warning("paddlenlp_ops does not exist, please install paddlenlp_ops.")
+            return
+        tensor_parallel_degree = paddle.distributed.get_world_size()
+        if tensor_parallel_degree > 1:
+            strategy = fleet.DistributedStrategy()
+            strategy.hybrid_configs = {
+                "dp_degree": 1,
+                "mp_degree": tensor_parallel_degree,
+                "pp_degree": 1,
+                "sharding_degree": 1,
+            }
+            fleet.init(is_collective=True, strategy=strategy)
+
+        paddle.set_device(self.predictor_args.device)
+        paddle.set_default_dtype(self.predictor_args.dtype)
+        from paddlenlp.utils.env import USE_FAST_TOKENIZER
+
+        self.tokenizer.use_fast = USE_FAST_TOKENIZER
+
+        # init chat_template for tokenizer
+        llm_utils.init_chat_template(self.tokenizer, self.model_name_or_path, self.predictor_args.chat_template)
+        tensor_parallel_rank, tensor_parallel_degree = llm_utils.init_dist_env()
+        # TODO(wj-Mcat): fix llama tokenizer pad_token bug
+        if (isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer))) and not self.tokenizer.pad_token:
+            self.tokenizer.pad_token = self.tokenizer.eos_token
+
+        model = AutoInferenceModelForCausalLM.from_pretrained(
+            self.model_name_or_path,
+            config=model_config,
+            predictor_args=self.predictor_args,
+            model_args=self.model_args,
+            dtype=self.dtype,
+            tensor_parallel_degree=tensor_parallel_degree,
+            tensor_parallel_rank=tensor_parallel_rank,
+        )
+
+        predictor_class_name = (
+            "DygraphBlockInferenceHiddenPredictor"  # execute_mode + inference_mode + "Hidden" + "Predictor"
+        )
+
+        import_class = sys.modules[__name__]
+        predictor_class = getattr(import_class, predictor_class_name)
+
+        cache_kvs_shape = None  # used for not block_attn/append_attn
+        cache_k_shapes = None  # used for block_attn/append_attn
+        cache_v_shapes = None  # used for block_attn/append_attn
+
+        predictor = predictor_class(
+            self.predictor_args,
+            tokenizer=self.tokenizer,
+            model=model,
+            cache_k_shapes=cache_k_shapes,
+            cache_v_shapes=cache_v_shapes,
+            cache_kvs_shape=cache_kvs_shape,
+            model_args=self.model_args,
+        )
+
+        return predictor
+
+    def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray:
+        """
+        This function will be used to encode queries for retrieval task
+        if there is an instruction for queries, we will add it to the query text
+        """
+
+        global MODEL_FLAG
+        global MAX_SEQ_LENGTH
+        global QUERY_DOC_FLAG_FOR_LLARA
+        MODEL_FLAG = self.model_flag
+        MAX_SEQ_LENGTH = self.max_seq_length
+        QUERY_DOC_FLAG_FOR_LLARA = "query"
+
+        if self.query_instruction is not None:
+            input_texts = [f"{self.query_instruction}{query}" for query in queries]
+        else:
+            input_texts = queries
+
+        assert isinstance(input_texts, list), "input_texts should be a list"
+        assert len(input_texts) == len(queries), f"Mismatch in number of queries: {len(input_texts)} vs {len(queries)}"
+
+        encode_results = self.encode_sentences(input_texts=input_texts)
+
+        assert isinstance(encode_results, np.ndarray), "encode_results should be a numpy array"
+        assert encode_results.shape[0] >= len(
+            input_texts
+        ), f"Encoded query count mismatch: {encode_results.shape[0]} vs {len(input_texts)}"
+
+        return encode_results[: len(input_texts)]
+
+    def encode_corpus(self, corpus: List[Union[Dict[str, str], str]], **kwargs) -> np.ndarray:
+        """
+        This function will be used to encode corpus for retrieval task
+        if there is an instruction for docs, we will add it to the doc text
+        """
+
+        global MODEL_FLAG
+        global QUERY_DOC_FLAG_FOR_LLARA
+        MODEL_FLAG = self.model_flag
+        QUERY_DOC_FLAG_FOR_LLARA = "doc"
+
+        if isinstance(corpus[0], dict):
+            if self.document_instruction is not None:
+                input_texts = [
+                    "{}{} {}".format(self.document_instruction, doc.get("title", ""), doc["text"]).strip()
+                    for doc in corpus
+                ]
+            else:
+                input_texts = ["{} {}".format(doc.get("title", ""), doc["text"]).strip() for doc in corpus]
+        else:
+            if self.document_instruction is not None:
+                input_texts = [f"{self.document_instruction}{doc}" for doc in corpus]
+            else:
+                input_texts = corpus
+
+        encode_results = self.encode_sentences(input_texts=input_texts)
+        assert encode_results.shape[0] >= len(
+            input_texts
+        ), f"Encoded query count mismatch: {encode_results.shape[0]} vs {len(input_texts)}"
+
+        return encode_results[: len(input_texts)]
+
+    def encode_sentences(self, input_texts):
+        encode_results = self.predictor.encode(input_texts)
+
+        return encode_results
diff --git a/tests/llm/test_predictor_v1.py b/tests/llm/test_predictor_v1.py
index 98cd7b55c51d..c75ac81a729e 100644
--- a/tests/llm/test_predictor_v1.py
+++ b/tests/llm/test_predictor_v1.py
@@ -119,6 +119,7 @@ def setUp(self) -> None:
         (
             {
                 "append_attn": True,
+                "return_full_hidden_states": True,
             },
         ),
     ]
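
Below is a minimal usage sketch for the quantized evaluation path this patch adds, following the `python -u evaluation/eval_mteb.py` invocation already documented in the README. Only `--quant_type` and `--kv_cache_reuse` are introduced by this patch; the model path, task name, batch size, and output folder shown here are hypothetical placeholders.

```bash
# Hypothetical invocation: run MTEB evaluation with weight-only INT8 loading and
# first-layer KV-cache reuse (the two flags added by this patch). The model path,
# task name, batch size and output folder below are placeholders.
python -u evaluation/eval_mteb.py \
    --base_model_name_or_path path/to/your/model \
    --task_name SciFact \
    --eval_batch_size 8 \
    --quant_type weight_only_int8 \
    --kv_cache_reuse 1 \
    --output_folder eval_results
```

With `--quant_type` left at its default `no`, the script keeps using the original `BiEncoderModel` path, so existing evaluations are unaffected.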