Skip to content

Commit e20e93f

Browse files
authored
Add MergeSubgraphSources sample pass to merge subgraph_sources.json when duplicates are found (#504)
* Add MergeSubgraphSources sample pass
* Modify it to be usable by apply_sample_pass
1 parent 3544cc1 commit e20e93f

File tree

2 files changed

+86
-5
lines changed

2 files changed

+86
-5
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
from graph_net.sample_pass.sample_pass import SamplePass
2+
import graph_net.subgraph_range_util as subgraph_range_util
3+
from pathlib import Path
4+
import json
5+
from typing import Union
6+
7+
8+
class MergeSubgraphSources(SamplePass):
    """Sample pass that merges ``subgraph_sources.json`` files of duplicate
    samples into a single target sample.

    Used during deduplication: when several samples share the same graph
    hash, the subgraph source ranges of the duplicates are folded into the
    kept sample (via ``subgraph_range_util.merge_subgraph_ranges``) before
    the duplicate samples are removed.
    """

    # Fallback file name when the config does not provide one; keep in sync
    # with the default declared in declare_config().
    DEFAULT_SOURCES_FILE_NAME = "subgraph_sources.json"

    def __init__(self, config=None):
        super().__init__(config)

    def declare_config(
        self,
        model_path_prefix: str = "",
        source_model_path_prefixes: list = None,
        subgraph_sources_json_file_name: str = "subgraph_sources.json",
    ):
        # Declaration only: the defaults are consumed by the SamplePass
        # config machinery; no logic is needed here.
        pass

    def __call__(self, rel_model_path: str):
        """Merge sources from every configured source prefix into the target.

        Args:
            rel_model_path: Model path relative to each configured prefix.
        """
        model_path_prefix = self.config.get("model_path_prefix", "")
        target_model_path = (
            Path(model_path_prefix) / rel_model_path
            if model_path_prefix
            else Path(rel_model_path)
        )

        source_model_path_prefixes = self.config.get("source_model_path_prefixes") or []
        source_model_paths = [
            Path(prefix) / rel_model_path for prefix in source_model_path_prefixes
        ]

        self.merge_sources_for_deduplication(target_model_path, source_model_paths)
        print(f"Merged {len(source_model_paths)} sources into {target_model_path}")

    def merge_sources_for_deduplication(
        self,
        target_model_path: Union[str, Path],
        source_model_paths: list[Union[str, Path]],
    ) -> dict[str, list[tuple[int, int]]]:
        """Merge the subgraph sources of ``source_model_paths`` into the target.

        Source samples with no sources file are skipped. The merged mapping
        is written back to the target's sources file and also returned.
        """
        merged_sources = self._load_sources(target_model_path)
        for source_path in source_model_paths:
            source_sources = self._load_sources(source_path)
            if source_sources:
                merged_sources = subgraph_range_util.merge_subgraph_ranges(
                    merged_sources, source_sources
                )
        self._save_sources(target_model_path, merged_sources)
        return merged_sources

    def _get_sources_file_path(self, model_path: Union[str, Path]) -> Path:
        # Fall back to the declared default so the pass also works when it is
        # constructed without a config (e.g. ``MergeSubgraphSources()`` in the
        # deduplication tool); the original hard lookup raised KeyError then.
        config = self.config or {}
        file_name = config.get(
            "subgraph_sources_json_file_name", self.DEFAULT_SOURCES_FILE_NAME
        )
        return Path(model_path) / file_name

    def _load_sources(
        self, model_path: Union[str, Path]
    ) -> dict[str, list[tuple[int, int]]]:
        """Return the parsed sources JSON for ``model_path``, or ``{}`` if the
        file does not exist."""
        file_path = self._get_sources_file_path(model_path)
        if not file_path.exists():
            return {}
        with open(file_path, "r", encoding="utf-8") as f:
            return json.load(f)

    def _save_sources(
        self,
        model_path: Union[str, Path],
        sources: dict[str, list[tuple[int, int]]],
    ) -> None:
        """Write ``sources`` to the sources JSON file under ``model_path``."""
        file_path = self._get_sources_file_path(model_path)
        with open(file_path, "w", encoding="utf-8") as f:
            json.dump(sources, f, indent=4)

graph_net/tools/deduplicated.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,9 @@
77
from collections import defaultdict
88
from pathlib import Path
99
from typing import Dict, List
10+
from graph_net.sample_pass.merge_subgraph_sources import (
11+
MergeSubgraphSources,
12+
)
1013

1114

1215
def collect_graph_hashs(samples_dir: Path) -> Dict[str, List[Path]]:
@@ -26,13 +29,21 @@ def main(args):
2629
print(f"Copy samples: {args.samples_dir} -> {args.target_dir}")
2730
shutil.copytree(args.samples_dir, args.target_dir)
2831
graph_hash2model_paths = collect_graph_hashs(args.target_dir)
32+
merge_sources_pass = MergeSubgraphSources()
2933
num_removed_samples = 0
3034
for graph_hash, model_paths in graph_hash2model_paths.items():
31-
# Keep the first sample and move the rest.
32-
for idx in range(1, len(model_paths)):
33-
print(f"Remove {model_paths[idx]}")
34-
shutil.rmtree(model_paths[idx])
35-
num_removed_samples += 1
35+
if len(model_paths) > 1:
36+
# Keep the first sample and merge sources from all duplicates
37+
target_path = model_paths[0]
38+
duplicate_paths = model_paths[1:]
39+
merge_sources_pass.merge_sources_for_deduplication(
40+
target_path, duplicate_paths
41+
)
42+
# Remove the duplicate samples
43+
for dup_path in duplicate_paths:
44+
print(f"Remove {dup_path}")
45+
shutil.rmtree(dup_path)
46+
num_removed_samples += 1
3647
print(
3748
f"Totally {len(graph_hash2model_paths)} different graph_hashs, {num_removed_samples} samples are removed."
3849
)

0 commit comments

Comments
 (0)