Skip to content

Commit 02f31da

Browse files
committed
Support initializing input tensors to specified device.
1 parent 023f21e commit 02f31da

File tree

3 files changed

+26
-7
lines changed

3 files changed

+26
-7
lines changed

graph_net/torch/fx_graph_cache_util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@
44
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
55

66

7-
def parse_immutable_model_path_into_sole_graph_module(model_path):
7+
def parse_immutable_model_path_into_sole_graph_module(model_path, device=None):
88
model_path = os.path.realpath(model_path)
99
if model_path not in g_model_path2graph_module:
10-
module, inputs = get_torch_module_and_inputs(model_path)
10+
module, inputs = get_torch_module_and_inputs(model_path, device=device)
1111
g_model_path2graph_module[model_path] = parse_sole_graph_module(module, inputs)
1212
return copy.deepcopy(g_model_path2graph_module[model_path])
1313

graph_net/torch/fx_graph_module_util.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@ def get_num_ops(node):
1212
return sum(map(get_num_ops, fx_graph_module.graph.nodes))
1313

1414

15-
def get_torch_module_and_inputs(model_path, use_dummy_inputs=True):
15+
def get_torch_module_and_inputs(model_path, use_dummy_inputs=True, device=None):
1616
module = _get_torch_module(model_path)
1717
tensor_metas = _get_tensor_metas(model_path)
1818
inputs = _create_inputs_by_metas(module, tensor_metas, use_dummy_inputs)
19+
if device:
20+
inputs = [tensor.to(device=device) for tensor in inputs]
1921
return module, inputs
2022

2123

graph_net/torch/sample_passes/cumsum_num_kernels_generator.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ def declare_config(
2424
model_path_prefix: str,
2525
output_dir: str,
2626
resume: bool = False,
27+
device: str = "auto",
2728
start_offset_in_original_graph: int = 0,
2829
limits_handled_models: int = None,
2930
output_json_file_name: str = "cumsum_num_kernels.json",
@@ -39,19 +40,30 @@ def sample_handled(self, rel_model_path: str) -> bool:
3940

4041
def resume(self, rel_model_path: str):
4142
model_path = Path(self.config["model_path_prefix"]) / rel_model_path
43+
device = self._choose_device(self.config["device"])
4244
start_offset_in_original_graph = self.config["start_offset_in_original_graph"]
43-
analyzer = CumsumNumKernelsAnalyzer(model_path, start_offset_in_original_graph)
45+
analyzer = CumsumNumKernelsAnalyzer(
46+
model_path, device, start_offset_in_original_graph
47+
)
4448
cumsum_num_kernels = analyzer.analyze()
4549
cumsum_num_kernels_json = json.dumps(cumsum_num_kernels, indent=4)
4650
output_dir_path = Path(self.config["output_dir"]) / rel_model_path
4751
output_dir_path.mkdir(parents=True, exist_ok=True)
4852
output_file_path = output_dir_path / self.config["output_json_file_name"]
4953
output_file_path.write_text(cumsum_num_kernels_json)
5054

55+
def _choose_device(self, device) -> str:
56+
if device in ["cpu", "cuda"]:
57+
return device
58+
return "cuda" if torch.cuda.is_available() else "cpu"
59+
5160

5261
class CumsumNumKernelsAnalyzer:
53-
def __init__(self, model_path: Path, start_offset_in_original_graph: int):
62+
def __init__(
63+
self, model_path: Path, device: str, start_offset_in_original_graph: int
64+
):
5465
self.model_path = model_path
66+
self.device = device
5567
self.start_offset_in_original_graph = start_offset_in_original_graph
5668

5769
def analyze(self):
@@ -67,8 +79,12 @@ def analyze(self):
6779

6880
def _get_cumsum_num_kernels(self):
6981
model_path = str(self.model_path)
70-
module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
71-
gm = parse_immutable_model_path_into_sole_graph_module(model_path)
82+
module, inputs = get_torch_module_and_inputs(
83+
model_path, use_dummy_inputs=False, device=self.device
84+
)
85+
gm = parse_immutable_model_path_into_sole_graph_module(
86+
model_path, device=self.device
87+
)
7288
for start, end in self._get_ranges(gm):
7389
assert start == 0
7490
num_kernels = self._get_num_kernels_if_submodule_compiled(
@@ -97,6 +113,7 @@ def compile_and_count_num_kernels(m, seq_no):
97113
group_head_and_tail=False,
98114
chain_style=False,
99115
)
116+
100117
rewrited_gm(*inputs)
101118
assert mut_opt_num_kernels.is_some()
102119
return mut_opt_num_kernels.unwrap()

0 commit comments

Comments
 (0)