
Commit 7c5fee8

fix: add shortgpt script, fix formatting and imports
Parent: b9d98fd

File tree: 3 files changed (+93 −80 lines)

slm/pipelines/examples/contrastive_training/evaluation/eval_mteb.py

Lines changed: 4 additions & 4 deletions
@@ -31,7 +31,7 @@ class MSMARCOTITLE(AbsTaskRetrieval):
         dataset={
             "corpus_path": "Tevatron/msmarco-passage-corpus-new",
             "path": "mteb/msmarco",
-            "revision": "c5a29a104738b98a9e76336939199e264163d4a0",
+            "revision": "c5a29a104738b98a9e76336939199e264163d4a0",
         },
         name="MSMARCOTITLE",
         description="MS MARCO is a collection of datasets focused on deep learning in search",
@@ -53,9 +53,9 @@ class MSMARCOTITLE(AbsTaskRetrieval):
         bibtex_citation=None,
         n_samples=None,
         avg_character_length=None,
-        modalities = ["text"],
-        sample_creation = "created",
-        descriptive_stats = {}
+        modalities=["text"],
+        sample_creation="created",
+        descriptive_stats={},
     )

     def load_data(self, **kwargs):
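Note: the changes in this file are formatting only. PEP 8 discourages spaces around `=` for keyword arguments (pycodestyle E251), so `modalities = ["text"]` becomes `modalities=["text"]`, and the black formatter adds a trailing comma after the final argument (`descriptive_stats={},`).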

slm/pipelines/examples/contrastive_training/evaluation/prediction.py

Lines changed: 2 additions & 10 deletions
@@ -11,16 +11,11 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
-import sys
-
 import numpy as np
 import paddle

 from paddlenlp.data import DataCollatorWithPadding
-from paddlenlp.transformers import AutoTokenizer
-
-from paddlenlp.transformers import BiEncoderModel
+from paddlenlp.transformers import AutoTokenizer, BiEncoderModel


 class Eval_model:
@@ -45,10 +40,7 @@ def _construct_model(self):
         """
         if self.model_type in ["bert", "roberta", "ernie"]:
             self._model = BiEncoderModel(
-                model_name_or_path=self.model,
-                normalized=True,
-                sentence_pooling_method="cls",
-                dtype='float32'
+                model_name_or_path=self.model, normalized=True, sentence_pooling_method="cls", dtype="float32"
             )
             print(f"loading checkpoints {self.model}")
         else:
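Note: both hunks here are lint cleanup. The unused `os` and `sys` imports are dropped (flake8 F401), the two separate `from paddlenlp.transformers import ...` lines are merged into one, and the multi-line `BiEncoderModel(...)` call is collapsed onto a single line with double-quoted strings, matching black's style.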

slm/pipelines/examples/contrastive_training/shortgpt_prune.py

Lines changed: 87 additions & 66 deletions
@@ -1,19 +1,33 @@
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import json
 import math
 import os
 import re
 from collections import OrderedDict
-from shutil import copyfile
-from typing import List, Optional
+from typing import List

-import numpy as np
 import paddle
 from datasets import load_dataset
 from paddle.io import DataLoader
-from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, NVEncodeModel
 from tqdm import tqdm

+from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, NVEncodeModel
+
+
 # =====================================================================================
 # 1. block_influence
 # =====================================================================================
@@ -33,19 +47,21 @@ def block_influence(
     norm_output = paddle.norm(output_hidden_state, p=2, axis=-1, keepdim=True)

     sim = paddle.matmul(input_hidden_state, output_hidden_state, transpose_y=True) / (norm_input * norm_output)
-    sim = paddle.diag(sim).astype('float32').nan_to_num(nan=0.5)
+    sim = paddle.diag(sim).astype("float32").nan_to_num(nan=0.5)

     if angular:
         return paddle.acos(sim) / math.pi
     return 1 - sim

+
 # =====================================================================================
-# 2. ShortGPT
+# 2. ShortGPT
 # =====================================================================================
 class ShortGPT:
     """
     A class to evaluate layer importance in LLMs using PaddlePaddle.
     """
+
     def __init__(self, model_name: str, layers_path: str):
         print(f"Loading tokenizer for '{model_name}'...")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -54,36 +70,30 @@ def __init__(self, model_name: str, layers_path: str):
         print(f"Loading model '{model_name}' with PaddlePaddle backend...")
         if "NV-Embed" in model_name:
             self.model = NVEncodeModel.from_pretrained(
-                model_name,
-                tokenizer_path=model_name,
-                query_instruction = "",
-                document_instruction =""
-            )
-        else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                dtype=paddle.float16
+                model_name, tokenizer_path=model_name, query_instruction="", document_instruction=""
             )
-
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype=paddle.float16)
+
         self.model.eval()
         print("Model loaded successfully for importance evaluation.")
-
+
         try:
-            path_parts = layers_path.split('.') # e.g., 'llama.layers' -> ['llama', 'layers']
+            path_parts = layers_path.split(".")  # e.g., 'llama.layers' -> ['llama', 'layers']

             self.base_model_for_call = self.model
             # Walk the path parts except the final 'layers' (e.g., 'llama')
             for part in path_parts[:-1]:
                 self.base_model_for_call = getattr(self.base_model_for_call, part)
-
+
             # Fetch the 'layers' list from the base model
             self.layers = getattr(self.base_model_for_call, path_parts[-1])
             print(f"Successfully located base model for evaluation call: {type(self.base_model_for_call)}")
             print(f"Successfully located {len(self.layers)} layers.")

         except AttributeError:
             raise AttributeError(f"Could not find layers at path '{layers_path}' in the model architecture.")
-
+
         self.importances = [0.0 for _ in self.layers]

     def compute_bi(self, hiddens: List[paddle.Tensor]):
@@ -95,20 +105,15 @@ def compute_bi(self, hiddens: List[paddle.Tensor]):
             layer_index = i
             if layer_index < len(self.importances):
                 in_hidden = hiddens[i]
-                out_hidden = hiddens[i+n]
-                self.importances[layer_index] += block_influence(
-                    in_hidden,
-                    out_hidden
-                ).sum().item()
+                out_hidden = hiddens[i + n]
+                self.importances[layer_index] += block_influence(in_hidden, out_hidden).sum().item()

     @paddle.no_grad()
     def eval_importance(self, prompts: List[str], model_name: str, stride: int = 256):
         """
         Evaluates the importance of model layers on given prompts.
         """
-        prompt_tokens = self.tokenizer(
-            prompts, padding=True, return_attention_mask=True, return_tensors='pd'
-        )
+        prompt_tokens = self.tokenizer(prompts, padding=True, return_attention_mask=True, return_tensors="pd")
         input_ids = prompt_tokens.input_ids
         attn_mask = prompt_tokens.attention_mask

@@ -117,32 +122,27 @@ def eval_importance(self, prompts: List[str], model_name: str, stride: int = 256
         for start in range(0, max_prompt_len, stride):
             seq_ids = (attn_mask.sum(axis=-1) > start).nonzero().squeeze()
             seq_ids = seq_ids.unsqueeze(0) if seq_ids.ndim == 0 else seq_ids
-
+
             if seq_ids.shape[0] == 0:
                 continue

-            inputs = input_ids[seq_ids, start:start+stride]
-            attn = attn_mask[seq_ids, start:start+stride]
+            inputs = input_ids[seq_ids, start : start + stride]
+            attn = attn_mask[seq_ids, start : start + stride]

             if "NV-Embed" in model_name:
                 outputs = self.base_model_for_call.m_forward(
-                    input_ids=inputs,
-                    attention_mask=attn,
-                    output_hidden_states=True,
-                    return_dict=True
-                )
+                    input_ids=inputs, attention_mask=attn, output_hidden_states=True, return_dict=True
+                )
             else:
                 outputs = self.base_model_for_call(
-                    input_ids=inputs,
-                    attention_mask=attn,
-                    output_hidden_states=True,
-                    return_dict=True
+                    input_ids=inputs, attention_mask=attn, output_hidden_states=True, return_dict=True
                 )
-
+
             if outputs.hidden_states:
                 self.compute_bi(outputs.hidden_states)
-
-def load_model_weights(model_folder_path: str) -> OrderedDict:
+
+
+def load_model_weights(model_folder_path: str) -> OrderedDict:
     print(f"Attempting to load model weights from FOLDER: '{model_folder_path}'...")

     # 1. Ensure the path is a valid directory
@@ -156,7 +156,7 @@ def load_model_weights(model_folder_path: str) -> OrderedDict:
     if os.path.isfile(index_path):
         # Case A: Sharded model format detected (index file found)
         print("Sharded model format detected (index file found).")
-        with open(index_path, 'r', encoding='utf-8') as f:
+        with open(index_path, "r", encoding="utf-8") as f:
            index_data = json.load(f)

        shard_files = sorted(list(set(index_data["weight_map"].values())))
@@ -190,9 +190,7 @@ def load_model_weights(model_folder_path: str) -> OrderedDict:
                 "but no 'model_state.pdparams.index.json' to specify order."
             )
     else:  # len(pdparams_files) == 0
-        raise FileNotFoundError(
-            f"No .pdparams files found in the directory '{model_folder_path}'."
-        )
+        raise FileNotFoundError(f"No .pdparams files found in the directory '{model_folder_path}'.")

     return state_dict

@@ -210,9 +208,9 @@ def prune_and_save_model_in_memory(
     """
     Prunes and saves a model directly from the in-memory model object.
     """
-    print("="*50)
+    print("=" * 50)
     print("PART 2: Starting In-Memory Model Pruning and Saving")
-    print("="*50)
+    print("=" * 50)
     os.makedirs(new_model_path, exist_ok=True)

     # Step 1: Get state_dict directly from the in-memory model
@@ -276,41 +274,63 @@ def prune_and_save_model_in_memory(

     print("\n🎉 Pruning process completed successfully!")
     print(f"Pruned model has been saved to '{new_model_path}'")
-


 def main():
     parser = argparse.ArgumentParser(
         description="Calculate layer importance, prune, and save a new PaddlePaddle model."
     )
-    parser.add_argument("--model_name_or_path", type=str, required=True, help="Path or HuggingFace name of the source PaddlePaddle model.")
-    parser.add_argument("--output_model_path", type=str, required=True, help="Path to save the new, pruned model directory.")
-    parser.add_argument("--layers_path", type=str, required=True, help="Dot-separated path to the layers list (e.g., 'llama.layers').")
-    parser.add_argument("--n_prune_layers", type=int, required=True, help="The number of layers to identify and prune.")
-    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19", help="Name of the Hugging Face dataset for calibration. Default: 'emozilla/pg19'.")
-    parser.add_argument("--dataset_split", type=str, default="validation", help="The split of the dataset to use. Default: 'validation'.")
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        required=True,
+        help="Path or HuggingFace name of the source PaddlePaddle model.",
+    )
+    parser.add_argument(
+        "--output_model_path", type=str, required=True, help="Path to save the new, pruned model directory."
+    )
+    parser.add_argument(
+        "--layers_path", type=str, required=True, help="Dot-separated path to the layers list (e.g., 'llama.layers')."
+    )
+    parser.add_argument(
+        "--n_prune_layers", type=int, required=True, help="The number of layers to identify and prune."
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="emozilla/pg19",
+        help="Name of the Hugging Face dataset for calibration. Default: 'emozilla/pg19'.",
+    )
+    parser.add_argument(
+        "--dataset_split",
+        type=str,
+        default="validation",
+        help="The split of the dataset to use. Default: 'validation'.",
+    )
     args = parser.parse_args()

     # --- PART 1: Calculate Layer Importance ---
-    print("="*50)
+    print("=" * 50)
     print("PART 1: Calculating Layer Importance")
-    print("="*50)
+    print("=" * 50)
     print(f"Loading '{args.dataset_split}' split from '{args.dataset_name}' dataset for calibration...")
     try:
         data = load_dataset(args.dataset_name, split=args.dataset_split)
     except Exception as e:
         print(f"Failed to load dataset. Error: {e}")
-        print("Please ensure the dataset name and split are correct and you have internet access for Hugging Face datasets.")
+        print(
+            "Please ensure the dataset name and split are correct and you have internet access for Hugging Face datasets."
+        )
         return
-
+
     dataloader = DataLoader(data, batch_size=1, shuffle=False)
-
+
     short_model = ShortGPT(model_name=args.model_name_or_path, layers_path=args.layers_path)
-
+
     for batch in tqdm(dataloader, desc="Evaluating Layer Importance"):
-        if 'text' not in batch:
+        if "text" not in batch:
             raise ValueError("Dataset must contain a 'text' column.")
-        prompts = batch['text']
+        prompts = batch["text"]
         short_model.eval_importance(prompts=prompts, model_name=args.model_name_or_path, stride=256)

     prune_order = sorted(range(len(short_model.importances)), key=lambda i: short_model.importances[i])
@@ -327,8 +347,9 @@ def main():
         tokenizer=short_model.tokenizer,
         new_model_path=args.output_model_path,
         layers_to_delete=layers_to_delete,
-        layers_path_str=args.layers_path
+        layers_path_str=args.layers_path,
     )

+
 if __name__ == "__main__":
-    main()
+    main()
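Note on the new script: ShortGPT ranks transformer layers by block influence (BI), where a layer's BI is one minus the average cosine similarity between its input and output hidden states; layers that barely change the representation score lowest and are pruned first (the `angular=True` variant maps similarity to an arccos distance instead). The sketch below re-derives the same score in plain NumPy to make the computation concrete; it is a stand-alone illustration, not part of the commit, and the function and variable names are invented here.

import numpy as np

def block_influence_np(h_in: np.ndarray, h_out: np.ndarray) -> np.ndarray:
    """Token-wise BI = 1 - cos(h_in, h_out) for hidden states of shape [seq_len, dim]."""
    norm_in = np.linalg.norm(h_in, axis=-1, keepdims=True)
    norm_out = np.linalg.norm(h_out, axis=-1, keepdims=True)
    # Row-wise cosine similarity; equivalent to taking the diagonal of the
    # matmul-based similarity matrix used in block_influence above.
    sim = np.sum(h_in * h_out, axis=-1) / (norm_in * norm_out).squeeze(-1)
    sim = np.nan_to_num(sim, nan=0.5)  # same NaN fallback as the Paddle version
    return 1.0 - sim

rng = np.random.default_rng(0)
h = rng.standard_normal((4, 64))
# A near-identity layer: output ~ input, so BI ~ 0 (prime pruning candidate).
print(block_influence_np(h, h + 0.01 * rng.standard_normal((4, 64))).mean())
# An unrelated output: cosine similarity ~ 0, so BI ~ 1 (important layer).
print(block_influence_np(h, rng.standard_normal((4, 64))).mean())

Given the argparse interface defined in main(), a run would look something like `python shortgpt_prune.py --model_name_or_path <model> --layers_path llama.layers --n_prune_layers 4 --output_model_path ./pruned_model` (model path, output path, and layer count are placeholders; `llama.layers` is the example from the script's own help text, and `--dataset_name`/`--dataset_split` default to `emozilla/pg19`/`validation`).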
