Skip to content

Commit 6d538c7

Browse files
authored
Fix predictor bug (#8947)
* fix bug * update
1 parent 310a7bc commit 6d538c7

File tree

2 files changed

+3
-8
lines changed

2 files changed

+3
-8
lines changed

llm/predict/predictor.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -773,12 +773,8 @@ def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
773773

774774
self.dtype = config.dtype or self.model_config.dtype
775775

776-
try:
777-
self.rope_theta = self.model_config.rope_theta
778-
self.rope_scaling = self.model_config.rope_scaling
779-
except:
780-
self.rope_theta = 10000.0
781-
self.rope_scaling = None
776+
self.rope_theta = self.model_config.get("rope_theta", 10000.0)
777+
self.rope_scaling = self.model_config.get("rope_scaling", None)
782778

783779
self.pre_cache_length = 0
784780

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -980,7 +980,6 @@ def set_state_dict(self, state_dict):
980980
act_scale_map_dict = scale_map_dict["act_scale"]
981981
weight_scale_map_dict = scale_map_dict["weight_scale"]
982982
cache_scale_map_dict = scale_map_dict["cachekv_scale"]
983-
# TODO(RichardWooSJTU): support multi-cards
984983

985984
act_scale_json_path = os.path.join(self.quant_model_path, "act_scales.json")
986985
weight_scale_json_path = os.path.join(self.quant_model_path, "weight_scales.json")
@@ -1008,7 +1007,7 @@ def set_state_dict(self, state_dict):
10081007
cache_scale_json_path = os.path.join(self.quant_model_path, "cachekv_scales.json")
10091008
if self.config.tensor_parallel_degree > 1 and not self.config.single_card_ptq:
10101009
cache_scale_json_path = os.path.join(
1011-
self.quant_model_path, f"cachekv_act_scales_{self.config.tensor_parallel_rank}.json"
1010+
self.quant_model_path, f"cachekv_scales_{self.config.tensor_parallel_rank}.json"
10121011
)
10131012
cache_scales_loader = CacheScaleLoader(
10141013
cache_scale_json_path,

0 commit comments

Comments
 (0)