|
| 1 | +import os |
| 2 | +import torch |
| 3 | +import json |
| 4 | +import shutil |
| 5 | +from typing import Union, Callable |
| 6 | +from graph_net.torch import utils |
| 7 | +from graph_net.torch.decompose_util import convert_to_submodules_graph |
| 8 | +from graph_net.torch.extractor import GraphExtractor as BuiltinGraphExtractor |
| 9 | + |
| 10 | + |
| 11 | +class GraphExtractor: |
| 12 | + def __init__( |
| 13 | + self, name, dynamic, mut_graph_codes=None, placeholder_auto_rename=False |
| 14 | + ): |
| 15 | + self.subgraph_counter = 0 |
| 16 | + self.name = name |
| 17 | + self.dynamic = dynamic |
| 18 | + self.mut_graph_codes = mut_graph_codes |
| 19 | + self.placeholder_auto_rename = placeholder_auto_rename |
| 20 | + self.workspace_path = os.environ.get("GRAPH_NET_EXTRACT_WORKSPACE") |
| 21 | + if not self.workspace_path: |
| 22 | + raise EnvironmentError( |
| 23 | + "Environment variable 'GRAPH_NET_EXTRACT_WORKSPACE' is not set." |
| 24 | + ) |
| 25 | + split_pos_str = os.environ.get("GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS") |
| 26 | + if split_pos_str is None: |
| 27 | + raise EnvironmentError( |
| 28 | + "Environment variable 'GRAPH_NET_NAIVE_DECOMPOSER_SPLIT_POS' is not set." |
| 29 | + ) |
| 30 | + self.split_positions = [int(pos) for pos in split_pos_str.split(",")] |
| 31 | + |
| 32 | + def __call__(self, gm: torch.fx.GraphModule, sample_inputs): |
| 33 | + return convert_to_submodules_graph( |
| 34 | + gm, |
| 35 | + split_positions=self.split_positions, |
| 36 | + submodule_hook=self.get_naive_decomposer_extractor, |
| 37 | + group_head_and_tail=False, |
| 38 | + ) |
| 39 | + |
| 40 | + def get_naive_decomposer_extractor(self, submodule, seq_no): |
| 41 | + return NaiveDecomposerExtractor(self, submodule, seq_no) |
| 42 | + |
| 43 | + |
| 44 | +class NaiveDecomposerExtractor(torch.nn.Module): |
| 45 | + def __init__(self, parent_graph_extractor, submodule, seq_no): |
| 46 | + super().__init__() |
| 47 | + self.parent_graph_extractor = parent_graph_extractor |
| 48 | + self.submodule = submodule |
| 49 | + self.seq_no = seq_no |
| 50 | + self.extracted = False |
| 51 | + name = f"{parent_graph_extractor.name}_{self.seq_no}" |
| 52 | + self.builtin_extractor = BuiltinGraphExtractor( |
| 53 | + name=name, |
| 54 | + dynamic=False, |
| 55 | + mut_graph_codes=[], |
| 56 | + placeholder_auto_rename=parent_graph_extractor.placeholder_auto_rename, |
| 57 | + ) |
| 58 | + |
| 59 | + def forward(self, *args): |
| 60 | + if not self.extracted: |
| 61 | + self.builtin_extractor(self.submodule, args) |
| 62 | + self.extracted = True |
| 63 | + return self.submodule(*args) |
0 commit comments