Skip to content

Commit cbce64f

Browse files
committed
Allow to specify device for SubgraphGenerator.
1 parent 5afb480 commit cbce64f

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

graph_net/tools/generate_subgraph_dataset.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,7 @@ EOF
204204
"model_path_prefix": "${DEVICE_REWRITED_OUTPUT_DIR}",
205205
"output_dir": "$FUSIBLE_SUBGRAPH_SAMPLES_DIR",
206206
"subgraph_ranges_json_root": "$FUSIBLE_SUBGRAPH_RANGES_DIR",
207+
"device": "cuda",
207208
"resume": ${RESUME}
208209
}
209210
}

graph_net/torch/sample_passes/subgraph_generator.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ def declare_config(
3030
subgraph_ranges_json_key: str = "subgraph_ranges",
3131
group_head_and_tail: bool = False,
3232
chain_style: bool = False,
33+
device: str = "auto",
3334
resume: bool = False,
3435
limits_handled_models: int = None,
3536
):
@@ -63,7 +64,10 @@ def _has_enough_subgraphs(self, rel_model_path, num_subgraphs):
6364
def resume(self, rel_model_path: str):
6465
model_path = os.path.join(self.config["model_path_prefix"], rel_model_path)
6566
torch.cuda.empty_cache()
66-
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
67+
device = self._choose_device(self.config["device"])
68+
module, inputs = get_torch_module_and_inputs(
69+
model_path, use_dummy_inputs=False, device=device
70+
)
6771
gm = parse_sole_graph_module(module, inputs)
6872
torch.cuda.empty_cache()
6973
subgraph_ranges = self._get_subgraph_ranges(rel_model_path)
@@ -105,6 +109,11 @@ def fn(submodule, seq_no):
105109

106110
return fn
107111

112+
def _choose_device(self, device) -> str:
113+
if device in ["cpu", "cuda"]:
114+
return device
115+
return "cuda" if torch.cuda.is_available() else "cpu"
116+
108117

109118
class NaiveDecomposerExtractorModule(torch.nn.Module):
110119
def __init__(

0 commit comments

Comments
 (0)