Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions graph_net/sample_pass/merge_subgraph_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
from graph_net.sample_pass.sample_pass import SamplePass
import graph_net.subgraph_range_util as subgraph_range_util
from pathlib import Path
import json
from typing import Union


class MergeSubgraphSources(SamplePass):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sample_pass目录下一般都能被apply_sample_pass直接调用。但咱们这里不完整。可以把代码放到tools目录下。

def __init__(self, config=None):
super().__init__(config)

def declare_config(
self,
model_path_prefix: str = "",
source_model_path_prefixes: list = None,
subgraph_sources_json_file_name: str = "subgraph_sources.json",
):
pass

def __call__(self, rel_model_path: str):
model_path_prefix = self.config.get("model_path_prefix", "")
target_model_path = (
Path(model_path_prefix) / rel_model_path
if model_path_prefix
else Path(rel_model_path)
)

source_model_path_prefixes = self.config.get("source_model_path_prefixes") or []
source_model_paths = [
Path(prefix) / rel_model_path for prefix in source_model_path_prefixes
]

self.merge_sources_for_deduplication(target_model_path, source_model_paths)
print(f"Merged {len(source_model_paths)} sources into {target_model_path}")

def merge_sources_for_deduplication(
self,
target_model_path: Union[str, Path],
source_model_paths: list[Union[str, Path]],
) -> dict[str, list[tuple[int, int]]]:
merged_sources = self._load_sources(target_model_path)
for source_path in source_model_paths:
source_sources = self._load_sources(source_path)
if source_sources:
merged_sources = subgraph_range_util.merge_subgraph_ranges(
merged_sources, source_sources
)
self._save_sources(target_model_path, merged_sources)
return merged_sources

def _get_sources_file_path(self, model_path: Union[str, Path]) -> Path:
return Path(model_path) / self.config["subgraph_sources_json_file_name"]

def _load_sources(
self, model_path: Union[str, Path]
) -> dict[str, list[tuple[int, int]]]:
file_path = self._get_sources_file_path(model_path)
if not file_path.exists():
return {}
with open(file_path, "r", encoding="utf-8") as f:
return json.load(f)

def _save_sources(
self,
model_path: Union[str, Path],
sources: dict[str, list[tuple[int, int]]],
) -> None:
file_path = self._get_sources_file_path(model_path)
with open(file_path, "w", encoding="utf-8") as f:
json.dump(sources, f, indent=4)
21 changes: 16 additions & 5 deletions graph_net/tools/deduplicated.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@
from collections import defaultdict
from pathlib import Path
from typing import Dict, List
from graph_net.sample_pass.merge_subgraph_sources import (
MergeSubgraphSources,
)


def collect_graph_hashs(samples_dir: Path) -> Dict[str, List[Path]]:
Expand All @@ -26,13 +29,21 @@ def main(args):
print(f"Copy samples: {args.samples_dir} -> {args.target_dir}")
shutil.copytree(args.samples_dir, args.target_dir)
graph_hash2model_paths = collect_graph_hashs(args.target_dir)
merge_sources_pass = MergeSubgraphSources()
num_removed_samples = 0
for graph_hash, model_paths in graph_hash2model_paths.items():
# Keep the first sample and move the rest.
for idx in range(1, len(model_paths)):
print(f"Remove {model_paths[idx]}")
shutil.rmtree(model_paths[idx])
num_removed_samples += 1
if len(model_paths) > 1:
# Keep the first sample and merge sources from all duplicates
target_path = model_paths[0]
duplicate_paths = model_paths[1:]
merge_sources_pass.merge_sources_for_deduplication(
target_path, duplicate_paths
)
# Remove the duplicate samples
for dup_path in duplicate_paths:
print(f"Remove {dup_path}")
shutil.rmtree(dup_path)
num_removed_samples += 1
print(
f"Totally {len(graph_hash2model_paths)} different graph_hashs, {num_removed_samples} samples are removed."
)
Expand Down