Skip to content

Commit b7d188a

Browse files
committed
add start signature and end signature iinformation in subgraph model name
1 parent 8f5693e commit b7d188a

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

graph_net/torch/sample_passes/subgraph_generator.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,9 @@ def resume(self, rel_model_path: str):
7070
split_positions = self._get_split_positions(subgraph_ranges)
7171
rewrited_gm: torch.fx.GraphModule = convert_to_submodules_graph(
7272
gm,
73-
submodule_hook=self.get_naive_decomposer_extractor(rel_model_path),
73+
submodule_hook=self.get_naive_decomposer_extractor(
74+
rel_model_path, subgraph_ranges
75+
),
7476
split_positions=split_positions,
7577
subgraph_ranges=subgraph_ranges,
7678
group_head_and_tail=self.config.get("group_head_and_tail", False),
@@ -90,13 +92,15 @@ def _get_split_positions(self, subgraph_ranges: list[(int, int)]):
9092
split_positions = [position for pair in subgraph_ranges for position in pair]
9193
return sorted(set(split_positions))
9294

93-
def get_naive_decomposer_extractor(self, rel_model_path):
95+
def get_naive_decomposer_extractor(self, rel_model_path, subgraph_ranges):
9496
def fn(submodule, seq_no):
9597
return NaiveDecomposerExtractorModule(
9698
config=self.config,
9799
parent_graph_rel_model_path=rel_model_path,
98100
submodule=submodule,
99101
seq_no=seq_no,
102+
subgraph_start=subgraph_ranges[seq_no][0],
103+
subgraph_end=subgraph_ranges[seq_no][1],
100104
)
101105

102106
return fn
@@ -109,6 +113,8 @@ def __init__(
109113
parent_graph_rel_model_path: str,
110114
submodule: torch.nn.Module,
111115
seq_no: int,
116+
subgraph_start: int,
117+
subgraph_end: int,
112118
):
113119
super().__init__()
114120
self.config = config
@@ -120,7 +126,7 @@ def __init__(
120126
if self.seq_no is None:
121127
self.model_name = parent_graph_model_name
122128
else:
123-
submodule_name = f"{parent_graph_model_name}_{self.seq_no}"
129+
submodule_name = f"{parent_graph_model_name}_start{subgraph_start}_end{subgraph_end}_{self.seq_no}"
124130
self.model_name = submodule_name
125131
self.builtin_extractor = BuiltinGraphExtractor(
126132
name=submodule_name,

0 commit comments

Comments
 (0)