diff --git a/graph_net/sample_pass/merge_subgraph_sources.py b/graph_net/sample_pass/merge_subgraph_sources.py
new file mode 100644
--- /dev/null
+++ b/graph_net/sample_pass/merge_subgraph_sources.py
@@ -0,0 +1,90 @@
+"""Sample pass that merges per-subgraph source ranges across model copies.
+
+Used by the deduplication tool: when several samples share a graph hash,
+the subgraph_sources.json records of the duplicates are merged into the
+single copy that is kept.
+"""
+from graph_net.sample_pass.sample_pass import SamplePass
+import graph_net.subgraph_range_util as subgraph_range_util
+from pathlib import Path
+import json
+from typing import Union
+
+
+class MergeSubgraphSources(SamplePass):
+    def __init__(self, config=None):
+        super().__init__(config)
+
+    def declare_config(
+        self,
+        model_path_prefix: str = "",
+        source_model_path_prefixes: list = None,
+        subgraph_sources_json_file_name: str = "subgraph_sources.json",
+    ):
+        # Body intentionally empty: presumably the SamplePass framework
+        # harvests config keys and defaults from this signature
+        # (NOTE(review): confirm against SamplePass).
+        pass
+
+    def __call__(self, rel_model_path: str):
+        """Merge sources from every configured source prefix into the target."""
+        model_path_prefix = self.config.get("model_path_prefix", "")
+        target_model_path = (
+            Path(model_path_prefix) / rel_model_path
+            if model_path_prefix
+            else Path(rel_model_path)
+        )
+
+        source_model_path_prefixes = self.config.get("source_model_path_prefixes") or []
+        source_model_paths = [
+            Path(prefix) / rel_model_path for prefix in source_model_path_prefixes
+        ]
+
+        self.merge_sources_for_deduplication(target_model_path, source_model_paths)
+        print(f"Merged {len(source_model_paths)} sources into {target_model_path}")
+
+    def merge_sources_for_deduplication(
+        self,
+        target_model_path: Union[str, Path],
+        source_model_paths: list[Union[str, Path]],
+    ) -> dict[str, list[tuple[int, int]]]:
+        """Merge each source's subgraph ranges into the target and persist.
+
+        Returns the merged mapping (subgraph name -> list of (start, end)).
+        """
+        merged_sources = self._load_sources(target_model_path)
+        for source_path in source_model_paths:
+            source_sources = self._load_sources(source_path)
+            if source_sources:
+                merged_sources = subgraph_range_util.merge_subgraph_ranges(
+                    merged_sources, source_sources
+                )
+        self._save_sources(target_model_path, merged_sources)
+        return merged_sources
+
+    def _get_sources_file_path(self, model_path: Union[str, Path]) -> Path:
+        # Use .get with the declared default (consistent with __call__) so
+        # the pass also works when constructed with an unpopulated config.
+        return Path(model_path) / self.config.get(
+            "subgraph_sources_json_file_name", "subgraph_sources.json"
+        )
+
+    def _load_sources(
+        self, model_path: Union[str, Path]
+    ) -> dict[str, list[tuple[int, int]]]:
+        """Load the subgraph-sources JSON for a model; {} when absent."""
+        file_path = self._get_sources_file_path(model_path)
+        if not file_path.exists():
+            return {}
+        with open(file_path, "r", encoding="utf-8") as f:
+            return json.load(f)
+
+    def _save_sources(
+        self,
+        model_path: Union[str, Path],
+        sources: dict[str, list[tuple[int, int]]],
+    ) -> None:
+        """Write the merged subgraph sources for a model back to disk."""
+        file_path = self._get_sources_file_path(model_path)
+        with open(file_path, "w", encoding="utf-8") as f:
+            json.dump(sources, f, indent=4)
diff --git a/graph_net/tools/deduplicated.py b/graph_net/tools/deduplicated.py
index 3158c760e..b4cccf6c9 100644
--- a/graph_net/tools/deduplicated.py
+++ b/graph_net/tools/deduplicated.py
@@ -7,6 +7,9 @@
 from collections import defaultdict
 from pathlib import Path
 from typing import Dict, List
+from graph_net.sample_pass.merge_subgraph_sources import (
+    MergeSubgraphSources,
+)
 
 
 def collect_graph_hashs(samples_dir: Path) -> Dict[str, List[Path]]:
@@ -26,13 +29,21 @@
     print(f"Copy samples: {args.samples_dir} -> {args.target_dir}")
     shutil.copytree(args.samples_dir, args.target_dir)
     graph_hash2model_paths = collect_graph_hashs(args.target_dir)
+    merge_sources_pass = MergeSubgraphSources()
     num_removed_samples = 0
     for graph_hash, model_paths in graph_hash2model_paths.items():
-        # Keep the first sample and move the rest.
-        for idx in range(1, len(model_paths)):
-            print(f"Remove {model_paths[idx]}")
-            shutil.rmtree(model_paths[idx])
-            num_removed_samples += 1
+        if len(model_paths) > 1:
+            # Keep the first sample and merge sources from all duplicates
+            target_path = model_paths[0]
+            duplicate_paths = model_paths[1:]
+            merge_sources_pass.merge_sources_for_deduplication(
+                target_path, duplicate_paths
+            )
+            # Remove the duplicate samples
+            for dup_path in duplicate_paths:
+                print(f"Remove {dup_path}")
+                shutil.rmtree(dup_path)
+                num_removed_samples += 1
     print(
         f"Totally {len(graph_hash2model_paths)} different graph_hashs, {num_removed_samples} samples are removed."
     )