diff --git a/graph_net/apply_sample_pass.py b/graph_net/apply_sample_pass.py index ae34b021d..ded070ff0 100644 --- a/graph_net/apply_sample_pass.py +++ b/graph_net/apply_sample_pass.py @@ -39,23 +39,32 @@ def _get_handler(args): def main(args): handler = _get_handler(args) if args.model_path is not None: + assert not hasattr(handler, "BEGIN") + assert not hasattr(handler, "END") handle_model_path(handler, args.model_path) elif args.use_subprocess: + assert not hasattr(handler, "BEGIN") + assert not hasattr(handler, "END") handle_model_path_list_in_subprocess(args) else: handle_model_path_list_in_current_process(handler, args) def handle_model_path_list_in_current_process(handler, args): - for model_path in _get_model_path_list(args): + rel_model_paths = list(_get_model_path_list(args)) + if hasattr(handler, "BEGIN"): + handler.BEGIN(rel_model_paths) + for rel_model_path in rel_model_paths: try: - handle_model_path(handler, model_path) + handle_model_path(handler, rel_model_path) except KeyboardInterrupt: print("KeyboardInterrupt") return except Exception: print("------------[apply_sample_pass failed]------------", flush=True) traceback.print_exc() + if hasattr(handler, "END"): + handler.END(rel_model_paths) def handle_model_path_list_in_subprocess(args): diff --git a/graph_net/bash_templates/apply_sample_pass_sh.txt b/graph_net/bash_templates/apply_sample_pass_sh.txt new file mode 100644 index 000000000..322176d43 --- /dev/null +++ b/graph_net/bash_templates/apply_sample_pass_sh.txt @@ -0,0 +1,16 @@ +#!/bin/bash + +GRAPH_NET_ROOT=$(python3 -c "import graph_net; import os; print(os.path.dirname(os.path.dirname(graph_net.__file__)))") + +python3 -m graph_net.apply_sample_pass \ + --model-path-list "customize_your_model_path_list" \ + --sample-pass-file-path "$GRAPH_NET_ROOT/graph_net/customize_your_sample_pass.py" \ + --sample-pass-class-name customize_your_class_name \ + --sample-pass-config $(base64 -w 0 < start + ] + + def get_range_idx2range_by_subgraph_ranges(): + assert subgraph_ranges is not None + num_nodes = len(submodules_body_nodes) + for i in range(len(subgraph_ranges)): + start, end = subgraph_ranges[i] + assert start >= 0 + assert start < end + assert end <= num_nodes + # check disjoint + assert i == 0 or start >= subgraph_ranges[i - 1][1], f"{i=}" + return subgraph_ranges + + range_idx2range = ( + get_range_idx2range_by_split_positions() + if chain_style + else get_range_idx2range_by_subgraph_ranges() + ) + range_idx2submodule_body_nodes = [ + submodules_body_nodes[start:end] for start, end in range_idx2range + ] + + def get_body_nodes(range_idx): + return range_idx2submodule_body_nodes[range_idx] + + def get_start_node_idx(range_idx): + start_node = get_body_nodes(range_idx)[0] + for i, node in enumerate(gm.graph.nodes): + if node == start_node: + return i + raise NotImplementedError("Dead code.") + + def get_end_node_idx(range_idx): + last_node = get_body_nodes(range_idx)[-1] + for i, node in enumerate(gm.graph.nodes): + if node == last_node: + return i + 1 + raise NotImplementedError("Dead code.") + + num_subgraphs = len(range_idx2submodule_body_nodes) + for range_idx in range(num_subgraphs): + start, end = range_idx2range[range_idx] + ( + submodule_input_nodes, + submodule_output_nodes, + identity_nodes, + ) = _get_submodule_inputs_and_outputs( + gm=gm, + start_node_idx=get_start_node_idx(range_idx), + end_node_idx=get_end_node_idx(range_idx), + chain_style=chain_style, + ) + yield start, end, submodule_input_nodes + + def convert_to_submodules_graph( gm: torch.fx.GraphModule, split_positions: list[int], diff --git a/graph_net/torch/sample_pass/shape_propagator.py b/graph_net/torch/sample_pass/shape_propagator.py new file mode 100644 index 000000000..1eed8fe93 --- /dev/null +++ b/graph_net/torch/sample_pass/shape_propagator.py @@ -0,0 +1,98 @@ +from graph_net.sample_pass.sample_pass import SamplePass +from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin +from graph_net.torch.fx_graph_cache_util import ( + parse_immutable_model_path_into_sole_graph_module, +) +from graph_net.torch.fx_graph_module_util import ( + get_torch_module_and_inputs, +) +from pathlib import Path +import json +import torch +from torch.fx.passes.shape_prop import ShapeProp + + +class ShapePropagator(SamplePass, ResumableSamplePassMixin): + def __init__(self, config): + super().__init__(config) + + def declare_config( + self, + model_path_prefix: str, + output_dir: str, + device: str = "auto", + output_json_file_name: str = "shape_prop.json", + shape_prop_json_key: str = "op_name_and_tensor_output_shape_list", + resume: bool = False, + limits_handled_models: int = None, + ): + pass + + def __call__(self, rel_model_path: str): + self.resumable_handle_sample(rel_model_path) + + def sample_handled(self, rel_model_path: str) -> bool: + file_name = self.config["output_json_file_name"] + return self.naive_sample_handled(rel_model_path, search_file_name=file_name) + + def resume(self, rel_model_path: str): + model_path = Path(self.config["model_path_prefix"]) / rel_model_path + device = self._choose_device(self.config["device"]) + shape_prop = FxGraphShapePropagator(model_path, device) + op_and_shapes = shape_prop.infer_op_name_and_tensor_output_shape_list() + json_obj = { + self.config["shape_prop_json_key"]: op_and_shapes, + } + op_and_shapes_json = json.dumps(json_obj, indent=4) + output_dir_path = Path(self.config["output_dir"]) / rel_model_path + output_dir_path.mkdir(parents=True, exist_ok=True) + output_file_path = output_dir_path / self.config["output_json_file_name"] + output_file_path.write_text(op_and_shapes_json) + + def _choose_device(self, device) -> str: + if device in ["cpu", "cuda"]: + return device + return "cuda" if torch.cuda.is_available() else "cpu" + + +class FxGraphShapePropagator: + def __init__(self, model_path: Path, device: str): + self.model_path = model_path + self.device = device + + def infer_op_name_and_tensor_output_shape_list(self): + data = [ + (self._get_op_name(node), self._get_tensor_output_shape(node)) + for node in self._shape_propagated_nodes() + ] + return data + + def _get_tensor_output_shape(self, node): + meta = node.meta.get("tensor_meta") + if meta is None: + return None + if not hasattr(meta, "shape"): + return None + if not isinstance(meta.shape, (list, tuple)): + return None + return meta.shape + + def _get_op_name(self, node): + if node.op == "call_method": + return f"Tensor.{node.target}" + elif node.op == "call_function": + return getattr(node.target, "__name__", str(node.target)) + else: + return node.op + + def _shape_propagated_nodes(self): + model_path = str(self.model_path) + module, inputs = get_torch_module_and_inputs( + model_path, use_dummy_inputs=False, device=self.device + ) + gm = parse_immutable_model_path_into_sole_graph_module( + model_path, device=self.device + ) + ShapeProp(gm).propagate(*inputs) + for node in gm.graph.nodes: + yield node diff --git a/graph_net/torch/sample_pass/subgraph_input_producer_indexes_generator.py b/graph_net/torch/sample_pass/subgraph_input_producer_indexes_generator.py new file mode 100644 index 000000000..e7d6a802f --- /dev/null +++ b/graph_net/torch/sample_pass/subgraph_input_producer_indexes_generator.py @@ -0,0 +1,112 @@ +from graph_net.sample_pass.sample_pass import SamplePass +from typing import Generator +from graph_net.sample_pass.resumable_sample_pass_mixin import ResumableSamplePassMixin +import os +from pathlib import Path +import torch +import json + +from graph_net.torch.decompose_util import gen_submodule_input_nodes +from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs +from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module + + +class SubgraphInputProducerIndexesGenerator(SamplePass, ResumableSamplePassMixin): + def __init__(self, config): + super().__init__(config) + + def declare_config( + self, + output_dir: str, + model_path_prefix: str, + subgraph_ranges_json_root: str, + subgraph_ranges_json_file_name: str, + subgraph_ranges_json_key: str, + subgraph_ranges_json_rel_model_path_key: str = "subgraph_rel_model_paths", + output_json_file_name: str = "subgraph_input_producer_indexes.json", + output_json_key: str = "input_producer_indexes", + output_json_subgraph_rel_model_path_key: str = "subgraph_rel_model_paths", + group_head_and_tail: bool = False, + chain_style: bool = False, + device: str = "auto", + resume: bool = False, + limits_handled_models: int = None, + ): + pass + + def __call__(self, rel_model_path: str): + self.resumable_handle_sample(rel_model_path) + + def sample_handled(self, rel_model_path: str) -> bool: + file_name = self.config["output_json_file_name"] + return self.naive_sample_handled(rel_model_path, search_file_name=file_name) + + def resume(self, rel_model_path: str): + subgraph_input_producer_indexes = self._get_subgraph_input_producer_indexes( + rel_model_path + ) + dst_model_path = Path(self.config["output_dir"]) / rel_model_path + dst_model_path.mkdir(parents=True, exist_ok=True) + json_str = json.dumps(subgraph_input_producer_indexes, indent=4) + (dst_model_path / self.config["output_json_file_name"]).write_text(json_str) + + def _get_subgraph_input_producer_indexes(self, rel_model_path): + model_path = os.path.join(self.config["model_path_prefix"], rel_model_path) + torch.cuda.empty_cache() + device = self._choose_device(self.config["device"]) + module, inputs = get_torch_module_and_inputs( + model_path, use_dummy_inputs=False, device=device + ) + gm = parse_sole_graph_module(module, inputs) + torch.cuda.empty_cache() + subgraph_info_json = self._get_subgraph_info_json(rel_model_path) + + def get_subgraph_input_producer_indexes_json_obj(): + subgraph_ranges = self._get_subgraph_ranges(subgraph_info_json) + triples: Generator[(int, int, torch.fx.Node)] = gen_submodule_input_nodes( + gm, + subgraph_ranges=subgraph_ranges, + group_head_and_tail=self.config.get("group_head_and_tail", False), + chain_style=self.config.get("chain_style", False), + ) + node2node_idx = dict((node, i) for i, node in enumerate(gm.graph.nodes)) + input_producer_indexes = [ + { + "range_start": start, + "range_end": end, + "input_producer_indexes": [node2node_idx[node] for node in nodes], + } + for start, end, nodes in triples + ] + return {self.config["output_json_key"]: input_producer_indexes} + + def get_subgraph_rel_model_paths_json_obj(): + return { + self.config[ + "output_json_subgraph_rel_model_path_key" + ]: self._get_subgraph_paths(subgraph_info_json) + } + + return { + **get_subgraph_input_producer_indexes_json_obj(), + **get_subgraph_rel_model_paths_json_obj(), + } + + def _get_subgraph_info_json(self, rel_model_path: str) -> dict[str, list]: + model_path = Path(self.config["subgraph_ranges_json_root"]) / rel_model_path + file_path = model_path / self.config["subgraph_ranges_json_file_name"] + json_str = file_path.read_text() + return json.loads(json_str) + + def _get_subgraph_ranges(self, subgraph_ranges_and_paths_json) -> list[(int, int)]: + key = self.config["subgraph_ranges_json_key"] + return subgraph_ranges_and_paths_json[key] + + def _get_subgraph_paths(self, subgraph_ranges_and_paths_json) -> list[str]: + key = self.config["subgraph_ranges_json_rel_model_path_key"] + return subgraph_ranges_and_paths_json[key] + + def _choose_device(self, device) -> str: + if device in ["cpu", "cuda"]: + return device + return "cuda" if torch.cuda.is_available() else "cpu"