+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import argparse
 import json
 import math
 import os
 import re
 from collections import OrderedDict
-from shutil import copyfile
-from typing import List, Optional
+from typing import List

-import numpy as np
 import paddle
 from datasets import load_dataset
 from paddle.io import DataLoader
-from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, NVEncodeModel
 from tqdm import tqdm

+from paddlenlp.transformers import AutoModelForCausalLM, AutoTokenizer, NVEncodeModel
+
+
 # =====================================================================================
 # 1. block_influence
 # =====================================================================================
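For context on the scoring in the hunk below: block influence is one minus the row-wise cosine similarity between a layer's input and output hidden states, so a layer that barely transforms its input scores near zero and becomes a pruning candidate. A minimal standalone sketch with toy tensors (not part of this diff; names are illustrative):

```python
import paddle

h_in = paddle.rand([4, 8])   # toy hidden states entering a layer (4 tokens, dim 8)
h_out = paddle.rand([4, 8])  # toy hidden states leaving the same layer

norm_in = paddle.norm(h_in, p=2, axis=-1, keepdim=True)
norm_out = paddle.norm(h_out, p=2, axis=-1, keepdim=True)

# Pairwise cosine similarities; matching input/output pairs sit on the diagonal.
sim = paddle.matmul(h_in, h_out, transpose_y=True) / (norm_in * norm_out)
sim = paddle.diag(sim)

bi = 1 - sim  # higher score -> the layer changed its input more
print(bi.shape)  # [4], one score per token position
```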
@@ -33,19 +47,21 @@ def block_influence(
     norm_output = paddle.norm(output_hidden_state, p=2, axis=-1, keepdim=True)

     sim = paddle.matmul(input_hidden_state, output_hidden_state, transpose_y=True) / (norm_input * norm_output)
-    sim = paddle.diag(sim).astype('float32').nan_to_num(nan=0.5)
+    sim = paddle.diag(sim).astype("float32").nan_to_num(nan=0.5)

     if angular:
         return paddle.acos(sim) / math.pi
     return 1 - sim

+
 # =====================================================================================
-# 2. ShortGPT 
+# 2. ShortGPT
 # =====================================================================================
 class ShortGPT:
     """
     A class to evaluate layer importance in LLMs using PaddlePaddle.
     """
+
     def __init__(self, model_name: str, layers_path: str):
         print(f"Loading tokenizer for '{model_name}'...")
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -54,36 +70,30 @@ def __init__(self, model_name: str, layers_path: str):
         print(f"Loading model '{model_name}' with PaddlePaddle backend...")
         if "NV-Embed" in model_name:
             self.model = NVEncodeModel.from_pretrained(
-                model_name,
-                tokenizer_path=model_name,
-                query_instruction="",
-                document_instruction=""
-            )
-        else:
-            self.model = AutoModelForCausalLM.from_pretrained(
-                model_name,
-                dtype=paddle.float16
+                model_name, tokenizer_path=model_name, query_instruction="", document_instruction=""
             )
-
+        else:
+            self.model = AutoModelForCausalLM.from_pretrained(model_name, dtype=paddle.float16)
+
         self.model.eval()
         print("Model loaded successfully for importance evaluation.")
-
+
         try:
-            path_parts = layers_path.split('.')  # e.g., 'llama.layers' -> ['llama', 'layers']
+            path_parts = layers_path.split(".")  # e.g., 'llama.layers' -> ['llama', 'layers']

             self.base_model_for_call = self.model
             # Walk each part of the path except the trailing 'layers' (e.g., 'llama')
             for part in path_parts[:-1]:
                 self.base_model_for_call = getattr(self.base_model_for_call, part)
-
+
             # Fetch the 'layers' list from the base model
             self.layers = getattr(self.base_model_for_call, path_parts[-1])
             print(f"Successfully located base model for evaluation call: {type(self.base_model_for_call)}")
             print(f"Successfully located {len(self.layers)} layers.")

         except AttributeError:
             raise AttributeError(f"Could not find layers at path '{layers_path}' in the model architecture.")
-
+
         self.importances = [0.0 for _ in self.layers]

     def compute_bi(self, hiddens: List[paddle.Tensor]):
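For reviewers, a self-contained toy version of what `compute_bi` does with the offset `n` (assumed to be 1 here; the loop head sits outside the hunk below): `hiddens[i]` is the state entering layer `i`, `hiddens[i + n]` the state leaving it, and the summed distance is accumulated per layer:

```python
from typing import List

import paddle


def accumulate_bi(hiddens: List[paddle.Tensor], importances: List[float], n: int = 1) -> None:
    # hiddens[i] enters layer i, hiddens[i + n] leaves it; accumulate 1 - cosine.
    for i in range(len(hiddens) - n):
        if i < len(importances):
            cos = paddle.nn.functional.cosine_similarity(hiddens[i], hiddens[i + n], axis=-1)
            importances[i] += float((1 - cos).sum())


# 3 layers -> 4 hidden states (embedding output plus one per layer).
states = [paddle.rand([2, 5, 8]) for _ in range(4)]
scores = [0.0, 0.0, 0.0]
accumulate_bi(states, scores)
print(scores)
```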
@@ -95,20 +105,15 @@ def compute_bi(self, hiddens: List[paddle.Tensor]):
             layer_index = i
             if layer_index < len(self.importances):
                 in_hidden = hiddens[i]
-                out_hidden = hiddens[i+n]
-                self.importances[layer_index] += block_influence(
-                    in_hidden,
-                    out_hidden
-                ).sum().item()
+                out_hidden = hiddens[i + n]
+                self.importances[layer_index] += block_influence(in_hidden, out_hidden).sum().item()

     @paddle.no_grad()
     def eval_importance(self, prompts: List[str], model_name: str, stride: int = 256):
         """
         Evaluates the importance of model layers on given prompts.
         """
-        prompt_tokens = self.tokenizer(
-            prompts, padding=True, return_attention_mask=True, return_tensors='pd'
-        )
+        prompt_tokens = self.tokenizer(prompts, padding=True, return_attention_mask=True, return_tensors="pd")
         input_ids = prompt_tokens.input_ids
         attn_mask = prompt_tokens.attention_mask

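The windowing in the next hunk deserves a note: prompts are scored in `stride`-sized windows, and at each offset only the sequences that still have real tokens (`attn_mask.sum() > start`) are kept. A runnable sketch of just the selection logic, with a toy mask:

```python
import paddle

attn_mask = paddle.to_tensor([[1, 1, 1, 1, 1, 1], [1, 1, 1, 0, 0, 0]])  # two padded sequences
stride = 4

for start in range(0, attn_mask.shape[1], stride):
    # Keep only sequences that still have unpadded tokens at this offset.
    seq_ids = (attn_mask.sum(axis=-1) > start).nonzero().squeeze()
    seq_ids = seq_ids.unsqueeze(0) if seq_ids.ndim == 0 else seq_ids
    print(start, seq_ids.numpy())  # offset 0 keeps both rows; offset 4 keeps only row 0
```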
@@ -117,32 +122,27 @@ def eval_importance(self, prompts: List[str], model_name: str, stride: int = 256):
         for start in range(0, max_prompt_len, stride):
             seq_ids = (attn_mask.sum(axis=-1) > start).nonzero().squeeze()
             seq_ids = seq_ids.unsqueeze(0) if seq_ids.ndim == 0 else seq_ids
-
+
             if seq_ids.shape[0] == 0:
                 continue

-            inputs = input_ids[seq_ids, start:start+stride]
-            attn = attn_mask[seq_ids, start:start+stride]
+            inputs = input_ids[seq_ids, start : start + stride]
+            attn = attn_mask[seq_ids, start : start + stride]

             if "NV-Embed" in model_name:
                 outputs = self.base_model_for_call.m_forward(
-                    input_ids=inputs,
-                    attention_mask=attn,
-                    output_hidden_states=True,
-                    return_dict=True
-                )
+                    input_ids=inputs, attention_mask=attn, output_hidden_states=True, return_dict=True
+                )
             else:
                 outputs = self.base_model_for_call(
-                    input_ids=inputs,
-                    attention_mask=attn,
-                    output_hidden_states=True,
-                    return_dict=True
+                    input_ids=inputs, attention_mask=attn, output_hidden_states=True, return_dict=True
                 )
-
+
             if outputs.hidden_states:
                 self.compute_bi(outputs.hidden_states)
-
-def load_model_weights(model_folder_path: str) -> OrderedDict:
+
+
+def load_model_weights(model_folder_path: str) -> OrderedDict:
     print(f"Attempting to load model weights from FOLDER: '{model_folder_path}'...")

     # 1. Ensure the path is a valid directory
@@ -156,7 +156,7 @@ def load_model_weights(model_folder_path: str) -> OrderedDict:
     if os.path.isfile(index_path):
         # Case A: Sharded model format detected (index file found)
         print("Sharded model format detected (index file found).")
-        with open(index_path, 'r', encoding='utf-8') as f:
+        with open(index_path, "r", encoding="utf-8") as f:
             index_data = json.load(f)

         shard_files = sorted(list(set(index_data["weight_map"].values())))
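For context, the index file consumed above is a small JSON map from parameter names to shard files; deduplicating its values yields the shard list. A sketch with illustrative file names:

```python
import json

# Illustrative contents of a model_state.pdparams.index.json file.
index_json = '{"weight_map": {"llama.layers.0.mlp.w": "model_state-00001.pdparams", "lm_head.weight": "model_state-00002.pdparams"}}'
index_data = json.loads(index_json)

shard_files = sorted(set(index_data["weight_map"].values()))
print(shard_files)  # ['model_state-00001.pdparams', 'model_state-00002.pdparams']
```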
@@ -190,9 +190,7 @@ def load_model_weights(model_folder_path: str) -> OrderedDict:
             "but no 'model_state.pdparams.index.json' to specify order."
         )
     else:  # len(pdparams_files) == 0
-        raise FileNotFoundError(
-            f"No .pdparams files found in the directory '{model_folder_path}'."
-        )
+        raise FileNotFoundError(f"No .pdparams files found in the directory '{model_folder_path}'.")

     return state_dict

@@ -210,9 +208,9 @@ def prune_and_save_model_in_memory(
     """
     Prunes and saves a model directly from the in-memory model object.
     """
-    print("="*50)
+    print("=" * 50)
     print("PART 2: Starting In-Memory Model Pruning and Saving")
-    print("="*50)
+    print("=" * 50)
     os.makedirs(new_model_path, exist_ok=True)

     # Step 1: Get state_dict directly from the in-memory model
@@ -276,41 +274,63 @@ def prune_and_save_model_in_memory(

     print("\n🎉 Pruning process completed successfully!")
     print(f"Pruned model has been saved to '{new_model_path}'")
-


 def main():
     parser = argparse.ArgumentParser(
         description="Calculate layer importance, prune, and save a new PaddlePaddle model."
     )
-    parser.add_argument("--model_name_or_path", type=str, required=True, help="Path or HuggingFace name of the source PaddlePaddle model.")
-    parser.add_argument("--output_model_path", type=str, required=True, help="Path to save the new, pruned model directory.")
-    parser.add_argument("--layers_path", type=str, required=True, help="Dot-separated path to the layers list (e.g., 'llama.layers').")
-    parser.add_argument("--n_prune_layers", type=int, required=True, help="The number of layers to identify and prune.")
-    parser.add_argument("--dataset_name", type=str, default="emozilla/pg19", help="Name of the Hugging Face dataset for calibration. Default: 'emozilla/pg19'.")
-    parser.add_argument("--dataset_split", type=str, default="validation", help="The split of the dataset to use. Default: 'validation'.")
+    parser.add_argument(
+        "--model_name_or_path",
+        type=str,
+        required=True,
+        help="Path or HuggingFace name of the source PaddlePaddle model.",
+    )
+    parser.add_argument(
+        "--output_model_path", type=str, required=True, help="Path to save the new, pruned model directory."
+    )
+    parser.add_argument(
+        "--layers_path", type=str, required=True, help="Dot-separated path to the layers list (e.g., 'llama.layers')."
+    )
+    parser.add_argument(
+        "--n_prune_layers", type=int, required=True, help="The number of layers to identify and prune."
+    )
+    parser.add_argument(
+        "--dataset_name",
+        type=str,
+        default="emozilla/pg19",
+        help="Name of the Hugging Face dataset for calibration. Default: 'emozilla/pg19'.",
+    )
+    parser.add_argument(
+        "--dataset_split",
+        type=str,
+        default="validation",
+        help="The split of the dataset to use. Default: 'validation'.",
+    )
     args = parser.parse_args()

     # --- PART 1: Calculate Layer Importance ---
-    print("="*50)
+    print("=" * 50)
     print("PART 1: Calculating Layer Importance")
-    print("="*50)
+    print("=" * 50)
     print(f"Loading '{args.dataset_split}' split from '{args.dataset_name}' dataset for calibration...")
     try:
         data = load_dataset(args.dataset_name, split=args.dataset_split)
     except Exception as e:
         print(f"Failed to load dataset. Error: {e}")
-        print("Please ensure the dataset name and split are correct and you have internet access for Hugging Face datasets.")
+        print(
+            "Please ensure the dataset name and split are correct and you have internet access for Hugging Face datasets."
+        )
         return
-
+
     dataloader = DataLoader(data, batch_size=1, shuffle=False)
-
+
     short_model = ShortGPT(model_name=args.model_name_or_path, layers_path=args.layers_path)
-
+
     for batch in tqdm(dataloader, desc="Evaluating Layer Importance"):
-        if 'text' not in batch:
+        if "text" not in batch:
             raise ValueError("Dataset must contain a 'text' column.")
-        prompts = batch['text']
+        prompts = batch["text"]
         short_model.eval_importance(prompts=prompts, model_name=args.model_name_or_path, stride=256)

     prune_order = sorted(range(len(short_model.importances)), key=lambda i: short_model.importances[i])
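Between these two hunks (unchanged, so not shown) the lowest-scoring layers are selected from `prune_order`. A toy illustration of that selection, assuming `layers_to_delete` is taken from the first `n_prune_layers` entries:

```python
importances = [0.42, 0.07, 0.31, 0.05]  # toy accumulated BI totals per layer
n_prune_layers = 2

prune_order = sorted(range(len(importances)), key=lambda i: importances[i])
layers_to_delete = sorted(prune_order[:n_prune_layers])
print(prune_order)       # [3, 1, 2, 0] -> least important first
print(layers_to_delete)  # [1, 3]
```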
@@ -327,8 +347,9 @@ def main():
         tokenizer=short_model.tokenizer,
         new_model_path=args.output_model_path,
         layers_to_delete=layers_to_delete,
-        layers_path_str=args.layers_path
+        layers_path_str=args.layers_path,
     )

+
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
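With the new arguments in place, a typical invocation (script name and paths are illustrative) would be `python prune_layers.py --model_name_or_path <model> --layers_path llama.layers --n_prune_layers 4 --output_model_path ./pruned_model`, which scores layers on the pg19 validation split by default and then writes the pruned copy.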