# Patch summary (reviewer commentary; placed ahead of the `diff --git` header).
#
# Target file: paddlenlp/trainer/utils/load_hf_ckpt.py
#
# NOTE(review): in this copy the patch's line breaks and indentation have been
# collapsed into single spaces, so the hunks below cannot be applied as-is;
# the text is kept verbatim for reference.
#
# 1) prepare_tensor (hunk @@ -277,13 +277,73 @@)
#    - The old 2-D check ("transpose only if the transposed shape equals
#      dst_shape") is replaced. For a 2-D tensor the shape is unpacked as
#      (num_experts, hidden_size); hidden_size must equal dst_shape[0]
#      (enforced with `assert`), the tensor is cut to its first dst_shape[1]
#      rows when num_experts != dst_shape[1], and the [1, 0] transpose is
#      returned as a contiguous tensor.
#    - A new 1-D branch prints "Slice weight: ..." unconditionally, keeps the
#      first dst_shape[0] elements and returns the result.
#    - Anything else prints a "Fatal: shape not match here:" message with both
#      shapes and calls sys.exit().
#    NOTE(review): the shape check is an `assert`, which is skipped when Python
#    runs with -O.
#
# 2) hf_cache(path) (new function)
#    Resolves a checkpoint file path to a local copy and returns that path.
#    Derived paths:
#      cache_path      = /dev/shm/lshrun_<basename>
#      disk_cache_path = /root/paddlejob/tmpspace/liangshuhao/<basename>
#      lock_path       = cache_path + '.lock'
#    Lookup order (English gloss of the in-code "Case N" comments):
#      Case 1: the file is already in memory (cache_path exists) -> cache_path.
#      Case 2: the file is on local disk (disk_cache_path exists)
#              -> disk_cache_path.
#      Case 3: wait for another process to move the file into memory: the lock
#              file is created with open(lock_path, 'x'); on FileExistsError
#              the caller polls every 0.1 s until cache_path exists, then
#              returns cache_path.
#      Case 4: fetch from another machine's disk: the integer in the second
#              '-'-separated field of the basename, modulo
#              int(os.environ['TRAINERS_NUM']), indexes the comma-separated
#              os.environ['TRIANER_IP_LIST']; `scp` copies
#              root@<ip>:<disk_cache_path> onto lock_path, and on exit code 0
#              `mv` renames lock_path to cache_path.
#      Case 5: copy from the source path: `cp path lock_path` is retried every
#              10 s until it exits with 0, then `mv` renames lock_path to
#              cache_path.
#    NOTE(review): the file object returned by open(lock_path, 'x') is neither
#    bound nor closed.
#    NOTE(review): the Case 3 wait loop has no timeout; it only ends once
#    cache_path exists, so a lock file left behind by a process that never
#    finishes the copy keeps every waiter polling.
#    NOTE(review): int(basename.split('-')[1]) is unguarded; a basename without
#    a numeric second '-' field raises before Case 5 can run.
#    NOTE(review): 'TRIANER_IP_LIST' is spelled differently from 'TRAINERS_NUM'
#    -- possibly a typo for TRAINER; confirm against the launcher environment.
#    NOTE(review): the '/root/paddlejob/tmpspace/liangshuhao' directory and the
#    'lshrun_' prefix are hard-coded.
#
# 3) load_huggingface_ckpt (hunks @@ -328,8 +388,9 @@, @@ -359,12 +420,12 @@,
#    @@ -376,3 +437,4 @@)
#    - Prints "loading i/N: filename" for each required file.
#    - Every safe_open(...) path is now passed through hf_cache(...).
#    - Prints "End load huggingface ckpt" at the end of the last hunk.
#
diff --git a/paddlenlp/trainer/utils/load_hf_ckpt.py b/paddlenlp/trainer/utils/load_hf_ckpt.py index c0df004428ef..ffea9320098a 100644 --- a/paddlenlp/trainer/utils/load_hf_ckpt.py +++ b/paddlenlp/trainer/utils/load_hf_ckpt.py @@ -277,13 +277,73 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False): if len(tensor.shape) != 1: print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!") return tensor - if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape: + + if len(tensor.shape) == 2: + num_experts, hidden_size = tensor.shape + assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}" + if num_experts != dst_shape[1]: + print(f"Slice weight: {tensor.shape} -> {dst_shape}") + tensor = tensor[:dst_shape[1]] return paddle.transpose(tensor, perm=[1, 0]).contiguous() - print("shape not match here") + if len(tensor.shape) == 1: + print(f"Slice weight: {tensor.shape} -> {dst_shape}") + tensor = tensor[:dst_shape[0]] + return tensor + + print("Fatal: shape not match here:", tensor.shape, dst_shape) sys.exit() +def hf_cache(path): + print('looking up:', path) + import os, time, subprocess + basename = os.path.basename(path) + cache_path = os.path.join('/dev/shm', 'lshrun_' + basename) + disk_cache_path = os.path.join('/root/paddlejob/tmpspace/liangshuhao', basename) + lock_path = cache_path + '.lock' + begin = time.time() + + # Case 1: 文件在内存中 + if os.path.exists(cache_path): + print('hit mem cache:', cache_path) + return cache_path + + # Case 2: 文件在磁盘中 + if os.path.exists(disk_cache_path): + print('hit disk cache:', disk_cache_path) + return disk_cache_path + + # Case 3: 等待其他进程将文件搬运到内存 + try: + open(lock_path, 'x') + except FileExistsError: + print('waiting peer load:', cache_path) + while not os.path.exists(cache_path): + time.sleep(0.1) + print('peer done:', cache_path) + return cache_path + + # Case 4: 从其他机器的磁盘中取回 + ckpt_id = int(basename.split('-')[1]) + dst_rank = ckpt_id % 
int(os.environ['TRAINERS_NUM']) + dst_ip = os.environ['TRIANER_IP_LIST'].split(',')[dst_rank] + print('fetching:', f'root@{dst_ip}:{disk_cache_path}', '->', lock_path) + if subprocess.run(['scp', f'root@{dst_ip}:{disk_cache_path}', lock_path]).returncode == 0: + subprocess.run(['mv', lock_path, cache_path], check=True) + print(f'done fetch in {time.time() - begin:.3f}s:', cache_path) + return cache_path + + # Case 5: 从源地址取回 + print('copying:', path, '->', lock_path) + while subprocess.run(['cp', path, lock_path]).returncode: + print('retrying:', path, '->', lock_path) + time.sleep(10) # sometimes too many open files cause error + subprocess.run(['mv', lock_path, cache_path], check=True) + print(f'done copy in {time.time() - begin:.3f}s:', cache_path) + return cache_path + + def load_huggingface_ckpt(model, huggingface_ckpt_path): ckpt_pre = huggingface_ckpt_path @@ -328,8 +388,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path): check_list = [] print("Start load huggingface ckpt") for i, filename in enumerate(required_files): + print(f'loading {i + 1}/{len(required_files)}: {filename}') try: - with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f: + with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f: # 加载该文件包含的所有参数 pd_params = file_to_pd_param_name[filename] for pd_param in pd_params: @@ -359,12 +420,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path): if weight_map[hf_name[0]] == filename: tensor0 = f.get_tensor(hf_name[0]) with safe_open( - ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu" + hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu" ) as f_other: tensor1 = f_other.get_tensor(hf_name[1]) else: with safe_open( - ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu" + hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu" ) as f_other: tensor0 = f_other.get_tensor(hf_name[0]) tensor1 = 
f.get_tensor(hf_name[1]) @@ -376,3 +437,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path): except Exception as e: print(f"Error loading {filename}: {str(e)}") raise + print("End load huggingface ckpt")