from graph_net.sample_pass.sample_pass import SamplePass
from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin
from graph_net.torch.fx_graph_cache_util import (
    parse_immutable_model_path_into_sole_graph_module,
)
from graph_net.torch.decompose_util import convert_to_submodules_graph
from graph_net.torch.count_kernels_util import count_kernels
from graph_net.torch.fx_graph_module_util import (
    get_fx_graph_num_ops,
    get_torch_module_and_inputs,
)
from pathlib import Path
from typing import Optional
import json
import torch


class CumSumNumKernelsGenerator(SamplePass, ResumableSamplePassMixin):
    def __init__(self, config):
        super().__init__(config)

    def declare_config(
        self,
        model_path_prefix: str,
        output_dir: str,
        resume: bool = False,
        limits_handled_models: Optional[int] = None,
        output_json_file_name: str = "cumsum_num_kernels.json",
    ):
        # Declares the expected config keys via the signature; intentionally empty.
        pass

    def __call__(self, rel_model_path: str):
        self.resumable_handle_sample(rel_model_path)

    def sample_handled(self, rel_model_path: str) -> bool:
        file_name = self.config["output_json_file_name"]
        return self.naive_sample_handled(rel_model_path, search_file_name=file_name)

    def resume(self, rel_model_path: str):
        model_path = Path(self.config["model_path_prefix"]) / rel_model_path
        analyzer = CumsumNumKernelsAnalyzer(model_path)
        cumsum_num_kernels = analyzer.analyze()
        cumsum_num_kernels_json = json.dumps(cumsum_num_kernels, indent=4)
        output_dir_path = Path(self.config["output_dir"]) / rel_model_path
        # Ensure the per-model output directory exists before writing the report.
        output_dir_path.mkdir(parents=True, exist_ok=True)
        (output_dir_path / self.config["output_json_file_name"]).write_text(
            cumsum_num_kernels_json
        )


class CumsumNumKernelsAnalyzer:
    def __init__(self, model_path: Path):
        self.model_path = model_path

    def analyze(self):
        triples = list(self._get_cumsum_num_kernels())
        data = {
            "range_and_num_kernels": [
                ((start, end), num_kernels) for start, end, num_kernels in triples
            ],
        }
        return data

    def _get_cumsum_num_kernels(self):
        model_path = str(self.model_path)
        module, inputs = get_torch_module_and_inputs(model_path, use_dummy_inputs=False)
        gm = parse_immutable_model_path_into_sole_graph_module(model_path)
        # For each cumulative prefix range [0, end), compile that sub-range and
        # count the kernels it launches.
        for start, end in self._get_ranges(gm):
            assert start == 0
            num_kernels = self._get_num_kernels_if_submodule_compiled(
                graph_module=gm,
                nn_module=module,
                inputs=inputs,
                submodule_start=start,
                submodule_end=end,
            )
            print(f"subgraph_range=[{start}, {end})\t{num_kernels=}")
            yield start, end, num_kernels

    def _get_num_kernels_if_submodule_compiled(
        self, graph_module, nn_module, inputs, submodule_start, submodule_end
    ):
        torch.cuda.empty_cache()
        rewritten_gm: torch.fx.GraphModule = convert_to_submodules_graph(
            graph_module,
            submodule_hook=lambda m, seq_no: torch.compile(m),
            split_positions=[submodule_start, submodule_end],
            subgraph_ranges=[(submodule_start, submodule_end)],
            group_head_and_tail=False,
            chain_style=False,
        )
        _, num_kernels = count_kernels(rewritten_gm, inputs)
        return num_kernels

    def _get_ranges(self, gm):
        # Yield cumulative prefix ranges: [0, 1), [0, 2), ..., [0, num_ops).
        num_ops = get_fx_graph_num_ops(gm)
        for i in range(num_ops):
            cum_num_ops = i + 1
            yield 0, cum_num_ops
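

# A minimal usage sketch (an assumption, not part of the framework): it treats the
# config as a plain dict whose keys mirror declare_config above, and the paths are
# placeholders.
#
#     config = {
#         "model_path_prefix": "/path/to/models",
#         "output_dir": "/path/to/output",
#         "resume": False,
#         "limits_handled_models": None,
#         "output_json_file_name": "cumsum_num_kernels.json",
#     }
#     generator = CumSumNumKernelsGenerator(config)
#     generator("some_model")  # rel_model_path under model_path_prefix
#
# The analyzer can also be run directly on a single model directory:
#
#     analyzer = CumsumNumKernelsAnalyzer(Path("/path/to/models/some_model"))
#     print(json.dumps(analyzer.analyze(), indent=4))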