Skip to content

Commit 2847ac9

Browse files
committed
Merge branch 'develop' of https://github.com/PaddlePaddle/GraphNet into subgraph
2 parents 0a93d2e + 7d7d2ca commit 2847ac9

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

graph_net/tools/generate_subgraph_dataset.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ function gen_fusible_subgraphs() {
199199
"model_path_prefix": "$2",
200200
"output_dir": "$4",
201201
"subgraph_ranges_json_root": "$3",
202+
"device": "cuda",
202203
"resume": ${RESUME}
203204
}
204205
}

graph_net/torch/sample_passes/subgraph_generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def declare_config(
3030
subgraph_ranges_json_key: str = "subgraph_ranges",
3131
group_head_and_tail: bool = False,
3232
chain_style: bool = False,
33+
device: str = "auto",
3334
resume: bool = False,
3435
limits_handled_models: int = None,
3536
):
@@ -63,7 +64,10 @@ def _has_enough_subgraphs(self, rel_model_path, num_subgraphs):
6364
def resume(self, rel_model_path: str):
6465
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
6566
torch.cuda.empty_cache()
66-
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
67+
device = self._choose_device(self.config["device"])
68+
module, inputs = get_torch_module_and_inputs(
69+
model_path, use_dummy_inputs=False, device=device
70+
)
6771
gm = parse_sole_graph_module(module, inputs)
6872
torch.cuda.empty_cache()
6973
subgraph_ranges = self._get_subgraph_ranges(rel_model_path)
@@ -105,6 +109,11 @@ def fn(submodule, seq_no):
105109

106110
return fn
107111

112+
def _choose_device(self, device) -> str:
113+
if device in ["cpu", "cuda"]:
114+
return device
115+
return "cuda" if torch.cuda.is_available() else "cpu"
116+
108117

109118
class NaiveDecomposerExtractorModule(torch.nn.Module):
110119
def __init__(

0 commit comments

Comments
 (0)