@@ -203,7 +203,7 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
203203 )
204204
205205 current_idx = 0
206- split_points_set = set ()
206+ split_positions = set ()
207207 total_len = sum (token2len .get (t , 1 ) for t in seq_tokens )
208208
209209 for token_id in seq_tokens :
@@ -212,22 +212,22 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
212212
213213 if is_pattern :
214214 if current_idx > 0 :
215- split_points_set .add (current_idx )
215+ split_positions .add (current_idx )
216216 end_idx = current_idx + length
217217 if end_idx < total_len :
218- split_points_set .add (end_idx )
218+ split_positions .add (end_idx )
219219
220220 current_idx += length
221221
222- sorted_splits = sorted (list (split_points_set ))
222+ sorted_splits = sorted (list (split_positions ))
223223
224224 self ._print_analysis (
225225 model_name , str (original_path ), sorted_splits , total_len , full_model_ops
226226 )
227227
228- results [model_name ] = {
229- "path " : str ( original_path ) ,
230- "split_points " : sorted_splits ,
228+ results [str ( original_path ) ] = {
229+ "model_name " : model_name ,
230+ "split_positions " : sorted_splits ,
231231 "total_length" : total_len ,
232232 }
233233
0 commit comments