Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 67 additions & 5 deletions paddlenlp/trainer/utils/load_hf_ckpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,73 @@ def prepare_tensor(tensor, dst_shape, *, force_transpose=False):
if len(tensor.shape) != 1:
print("attention same shape not transpose !!!!!!!!!!!!!!!!!!!!!!")
return tensor
if len(tensor.shape) == 2 and paddle.transpose(tensor, perm=[1, 0]).contiguous().shape == dst_shape:

if len(tensor.shape) == 2:
num_experts, hidden_size = tensor.shape
assert hidden_size == dst_shape[0], f"Shape not match: {tensor.shape} {dst_shape}"
if num_experts != dst_shape[1]:
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
tensor = tensor[:dst_shape[1]]
return paddle.transpose(tensor, perm=[1, 0]).contiguous()

print("shape not match here")
if len(tensor.shape) == 1:
print(f"Slice weight: {tensor.shape} -> {dst_shape}")
tensor = tensor[:dst_shape[0]]
return tensor

print("Fatal: shape not match here:", tensor.shape, dst_shape)
sys.exit()


def hf_cache(path, *, mem_cache_dir='/dev/shm', disk_cache_dir='/root/paddlejob/tmpspace/liangshuhao'):
    """Return a local path for checkpoint file ``path``, staging a copy if needed.

    Lookup order:
      1. a copy already present in ``mem_cache_dir`` (named ``lshrun_<basename>``);
      2. a copy already present in ``disk_cache_dir`` (named ``<basename>``);
      3. a copy another process is currently staging into ``mem_cache_dir``;
      4. ``scp`` from the peer machine selected by the shard id in the file name;
      5. ``cp`` from ``path`` itself, retried until it succeeds.

    Args:
        path: Source location of the checkpoint file.
        mem_cache_dir: Directory holding the shared in-memory cache.
        disk_cache_dir: Directory holding the local (and each peer's) disk cache.

    Returns:
        Path of a readable copy of the file (memory cache or disk cache).

    Note:
        A missing or unreadable ``path`` makes step 5 retry forever, as the
        retry loop cannot tell a transient failure from a permanent one.
    """
    import os
    import subprocess
    import time

    print('looking up:', path)
    basename = os.path.basename(path)
    cache_path = os.path.join(mem_cache_dir, 'lshrun_' + basename)
    disk_cache_path = os.path.join(disk_cache_dir, basename)
    # The lock file doubles as the staging file: data is written into it and
    # it is then renamed to cache_path, so cache_path only ever appears complete.
    lock_path = cache_path + '.lock'
    begin = time.time()

    # Case 1: the file is already in memory.
    if os.path.exists(cache_path):
        print('hit mem cache:', cache_path)
        return cache_path

    # Case 2: the file is already on the local disk.
    if os.path.exists(disk_cache_path):
        print('hit disk cache:', disk_cache_path)
        return disk_cache_path

    # Case 3: wait for another process that is moving the file into memory.
    # Keep trying to take the lock while waiting, so that a loader which gave
    # up (and removed its lock) is replaced instead of being waited on forever.
    waiting = False
    while True:
        try:
            # 'x' fails if the file exists; close the handle right away.
            with open(lock_path, 'x'):
                pass
            break
        except FileExistsError:
            if not waiting:
                print('waiting peer load:', cache_path)
                waiting = True
        if os.path.exists(cache_path):
            print('peer done:', cache_path)
            return cache_path
        time.sleep(0.1)

    try:
        # A peer may have published the file between the Case 1 check and the
        # moment the lock was taken; do not fetch it a second time.
        if os.path.exists(cache_path):
            os.remove(lock_path)
            print('peer done:', cache_path)
            return cache_path

        # Case 4: fetch from the disk cache of another machine.
        # NOTE(review): 'TRIANER_IP_LIST' looks like a misspelling of
        # 'TRAINER_IP_LIST'; kept as-is, confirm what the cluster exports.
        try:
            ckpt_id = int(basename.split('-')[1])
            dst_rank = ckpt_id % int(os.environ['TRAINERS_NUM'])
            dst_ip = os.environ['TRIANER_IP_LIST'].split(',')[dst_rank]
        except (IndexError, KeyError, ValueError, ZeroDivisionError) as err:
            # No shard id in the name or no cluster info: use the source copy.
            print('skip peer fetch:', repr(err))
            dst_ip = None
        if dst_ip is not None:
            remote = f'root@{dst_ip}:{disk_cache_path}'
            print('fetching:', remote, '->', lock_path)
            if subprocess.run(['scp', remote, lock_path]).returncode == 0:
                os.replace(lock_path, cache_path)
                print(f'done fetch in {time.time() - begin:.3f}s:', cache_path)
                return cache_path

        # Case 5: copy from the source location.
        print('copying:', path, '->', lock_path)
        while subprocess.run(['cp', path, lock_path]).returncode:
            print('retrying:', path, '->', lock_path)
            time.sleep(10)  # sometimes too many open files cause error
        os.replace(lock_path, cache_path)
        print(f'done copy in {time.time() - begin:.3f}s:', cache_path)
        return cache_path
    except BaseException:
        # Never leave a stale lock behind: it would block every later caller.
        try:
            os.remove(lock_path)
        except OSError:
            pass  # already renamed or removed
        raise


def load_huggingface_ckpt(model, huggingface_ckpt_path):
ckpt_pre = huggingface_ckpt_path

Expand Down Expand Up @@ -328,8 +388,9 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
check_list = []
print("Start load huggingface ckpt")
for i, filename in enumerate(required_files):
print(f'loading {i + 1}/{len(required_files)}: {filename}')
try:
with safe_open(ckpt_pre + filename, framework="paddle", device="cpu") as f:
with safe_open(hf_cache(ckpt_pre + filename), framework="paddle", device="cpu") as f:
# 加载该文件包含的所有参数
pd_params = file_to_pd_param_name[filename]
for pd_param in pd_params:
Expand Down Expand Up @@ -359,12 +420,12 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
if weight_map[hf_name[0]] == filename:
tensor0 = f.get_tensor(hf_name[0])
with safe_open(
ckpt_pre + weight_map[hf_name[1]], framework="paddle", device="cpu"
hf_cache(ckpt_pre + weight_map[hf_name[1]]), framework="paddle", device="cpu"
) as f_other:
tensor1 = f_other.get_tensor(hf_name[1])
else:
with safe_open(
ckpt_pre + weight_map[hf_name[0]], framework="paddle", device="cpu"
hf_cache(ckpt_pre + weight_map[hf_name[0]]), framework="paddle", device="cpu"
) as f_other:
tensor0 = f_other.get_tensor(hf_name[0])
tensor1 = f.get_tensor(hf_name[1])
Expand All @@ -376,3 +437,4 @@ def load_huggingface_ckpt(model, huggingface_ckpt_path):
except Exception as e:
print(f"Error loading {filename}: {str(e)}")
raise
print("End load huggingface ckpt")
Loading