Skip to content

Commit 460dc06

Browse files
committed
change the key of split_results.json from model_name to model_path
1 parent 92afb85 commit 460dc06

File tree

2 files changed

+8
-10
lines changed

2 files changed

+8
-10
lines changed

graph_net/torch/graph_decomposer.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -206,9 +206,7 @@ def _make_config(
206206
def __call__(self, rel_model_path):
207207
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
208208
split_results = load_json(self.config["split_results_path"])
209-
split_positions = split_results[os.path.basename(rel_model_path)][
210-
"split_points"
211-
]
209+
split_positions = split_results[rel_model_path]["split_positions"]
212210
config = {
213211
"split_positions": split_positions,
214212
"group_head_and_tail": self.config.get("group_head_and_tail", False),

graph_net/torch/typical_sequence_split_points.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,7 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
203203
)
204204

205205
current_idx = 0
206-
split_points_set = set()
206+
split_positions = set()
207207
total_len = sum(token2len.get(t, 1) for t in seq_tokens)
208208

209209
for token_id in seq_tokens:
@@ -212,22 +212,22 @@ def analyze(self, model_paths_file: str, device: str) -> Dict[str, Dict]:
212212

213213
if is_pattern:
214214
if current_idx > 0:
215-
split_points_set.add(current_idx)
215+
split_positions.add(current_idx)
216216
end_idx = current_idx + length
217217
if end_idx < total_len:
218-
split_points_set.add(end_idx)
218+
split_positions.add(end_idx)
219219

220220
current_idx += length
221221

222-
sorted_splits = sorted(list(split_points_set))
222+
sorted_splits = sorted(list(split_positions))
223223

224224
self._print_analysis(
225225
model_name, str(original_path), sorted_splits, total_len, full_model_ops
226226
)
227227

228-
results[model_name] = {
229-
"path": str(original_path),
230-
"split_points": sorted_splits,
228+
results[str(original_path)] = {
229+
"model_name": model_name,
230+
"split_positions": sorted_splits,
231231
"total_length": total_len,
232232
}
233233

0 commit comments

Comments
 (0)