Skip to content

Commit ca1dabc

Browse files
committed
Fix get_incorrect_models.
1 parent 4f0b203 commit ca1dabc

File tree

1 file changed

+9
-8
lines changed

1 file changed

+9
-8
lines changed

graph_net/subgraph_decompose_and_evaluation_step.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -163,14 +163,15 @@ def get_incorrect_models(self, decompose_method):
163163

164164
incorrect_models = []
165165
for model_name, model_record in self.model_name2record.items():
166-
for subgraph_path in model_record.subgraph_paths:
167-
(
168-
extracted_model_name,
169-
subgraph_idx,
170-
) = extract_model_name_and_subgraph_idx(subgraph_path)
171-
assert extracted_model_name == model_name
172-
if subgraph_idx in model_record.incorrect_subgraph_idxs:
173-
incorrect_models.append(subgraph_path)
166+
assert model_record.subgraph_paths
167+
model_path_prefix = os.path.dirname(model_record.subgraph_paths[0])
168+
for subgraph_idx in model_record.incorrect_subgraph_idxs:
169+
subgraph_path = os.path.join(
170+
model_path_prefix, f"{model_name}_{subgraph_idx}"
171+
)
172+
if subgraph_idx == 0:
173+
assert subgraph_path in model_record.subgraph_paths
174+
incorrect_models.append(subgraph_path)
174175
return incorrect_models
175176

176177
def collect_decomposed_subgraphs(self, decomposed_samples_dir):

0 commit comments

Comments
 (0)