diff --git a/paddlenlp/transformers/model_utils.py b/paddlenlp/transformers/model_utils.py index b5756896c65a..ac2286556743 100644 --- a/paddlenlp/transformers/model_utils.py +++ b/paddlenlp/transformers/model_utils.py @@ -428,12 +428,12 @@ def _load_part_state_dict( part_state_dict.update(quant_state_dict) else: if key in tensor_parallel_split_mapping: - if len(py_safe_slice_.shape) == 0: + if len(py_safe_slice_.get_shape()) == 0: weight = tensor_parallel_split_mapping[key](py_safe_slice_.get()) else: weight = tensor_parallel_split_mapping[key](py_safe_slice_) else: - if len(py_safe_slice_.shape) == 0: + if len(py_safe_slice_.get_shape()) == 0: weight = py_safe_slice_.get() else: weight = py_safe_slice_[:]