|
5 | 5 | from dataclasses import dataclass |
6 | 6 |
|
7 | 7 |
|
def convert_to_submodules_graph(
    original_gm: torch.fx.GraphModule,
    split_positions: list[int],
    submodule_hook=None,
    submodule_name_prefix="extraced_submodule",
    group_head_and_tail=True,
):
    """Split the body of an FX graph into consecutive submodules.

    The non-placeholder, non-output nodes of ``original_gm`` are partitioned at
    ``split_positions`` (indices into that body-node list).  Each resulting
    range is extracted into its own ``torch.fx.GraphModule`` and the original
    nodes are replaced by a single ``call_module`` node plus ``getitem``
    projections for each output.

    Works on a deep copy; the input module is not mutated.

    Args:
        original_gm: module whose graph is rewritten (deep-copied first).
        split_positions: body-node indices at which to cut; values are clamped
            to ``[0, len(body)]``; empty ranges are dropped.
        submodule_hook: optional callable ``(submodule, range_idx) -> submodule``
            applied to each extracted module before it is attached.
        submodule_name_prefix: name for the first submodule; later ones get a
            ``_{range_idx}`` suffix.  NOTE(review): "extraced" is a historical
            typo kept for backward compatibility — renaming it would change the
            attribute names callers see.
        group_head_and_tail: when True, implicit cuts at 0 and len(body) are
            added so the head and tail segments also become submodules.

    Returns:
        The rewritten (copied) ``torch.fx.GraphModule``, recompiled.
    """
    original_gm = copy.deepcopy(original_gm)
    # Placeholders precede body nodes in graph order; this offset converts
    # body-node indices into whole-graph node indices below.
    num_placeholders = len(
        [node for node in original_gm.graph.nodes if node.op == "placeholder"]
    )
    # Snapshot of the body nodes taken BEFORE any rewriting; all ranges index
    # into this frozen list even as the live graph is mutated per iteration.
    submodules_body_nodes = [
        node
        for node in original_gm.graph.nodes
        if node.op
        not in {
            "placeholder",
            "output",
        }
    ]
    split_positions = (
        [0, *split_positions, len(submodules_body_nodes)]
        if group_head_and_tail
        else split_positions
    )
    # Clamp out-of-range cut points rather than raising.
    split_positions = [
        max(0, min(pos, len(submodules_body_nodes))) for pos in split_positions
    ]
    # Consecutive (start, end) pairs; the `if end > start` filter drops empty
    # ranges (e.g. duplicate or unsorted cut points).
    submodule_ranges = [
        (start, end)
        for i in range(len(split_positions) - 1)
        for start in [split_positions[i]]
        for end in [split_positions[i + 1]]
        if end > start
    ]

    def get_body_nodes(range_idx):
        # Body nodes of one range, from the pre-mutation snapshot.
        start, end = submodule_ranges[range_idx]
        return submodules_body_nodes[start:end]

    def get_name2sub_submodule():
        # NOTE(review): this collects call_module targets from ALL body nodes,
        # not just the current range, so every extracted submodule is built
        # with the full set of used modules.  Presumably harmless (extra
        # attributes on each submodule) — confirm this is intended.
        used_module_names = set(
            [node.target for node in submodules_body_nodes if node.op == "call_module"]
        )
        return {
            name: module
            for name, module in original_gm.named_modules()
            if name in used_module_names
        }

    for range_idx in range(len(submodule_ranges)):
        start_node_idx, end_node_idx = submodule_ranges[range_idx]
        # NOTE(review): these indices are offsets into the ORIGINAL node order,
        # but original_gm.graph has already been rewritten by earlier
        # iterations (k body nodes become 1 call_module + n_outputs getitem
        # nodes).  Unless _get_submodule_inputs_and_outputs compensates, the
        # indices for range_idx >= 1 may point at shifted positions — verify
        # against that helper's implementation.
        (
            submodule_input_nodes,
            submodule_output_nodes,
        ) = _get_submodule_inputs_and_outputs(
            original_gm=original_gm,
            start_node_idx=(num_placeholders + start_node_idx),
            end_node_idx=(num_placeholders + end_node_idx),
        )

        # These two accessors ignore their argument and close over the
        # variables computed just above for the current iteration.
        def get_input_nodes(range_idx):
            return submodule_input_nodes

        def get_output_nodes(range_idx):
            return submodule_output_nodes

        # First range keeps the bare prefix so single-range callers see the
        # same submodule name as before this API existed.
        submodule_name = (
            f"{submodule_name_prefix}_{range_idx}"
            if range_idx > 0
            else submodule_name_prefix
        )
        new_graph = torch.fx.Graph()
        # Create a mapping for nodes from original graph to new graph
        node_map = {}

        # Add placeholder nodes for inputs
        for original_node in get_input_nodes(range_idx):
            new_node = new_graph.placeholder(original_node.name)
            node_map[original_node] = new_node

        # Copy body nodes
        for original_node in get_body_nodes(range_idx):
            new_node = new_graph.node_copy(original_node, lambda x: node_map[x])
            node_map[original_node] = new_node

        # Add output nodes
        output_args = []
        for original_node in get_output_nodes(range_idx):
            output_args.append(node_map[original_node])
        new_graph.output(tuple(output_args))

        # Create the new GraphModule
        # This assumes no submodules are being extracted, or they are handled separately
        new_sub_module = torch.fx.GraphModule(get_name2sub_submodule(), new_graph)
        if submodule_hook is not None:
            # Hook receives (submodule, range_idx) and must return the module
            # to attach.
            new_sub_module = submodule_hook(new_sub_module, range_idx)
        # Replace with submodule node
        original_gm.add_submodule(submodule_name, new_sub_module)
        with original_gm.graph.inserting_after(get_body_nodes(range_idx)[-1]):
            submodule_node = original_gm.graph.call_module(
                submodule_name, tuple(get_input_nodes(range_idx))
            )
        prev_node = submodule_node
        # The submodule returns a tuple; project each element back out with
        # operator.getitem so downstream users see individual values.
        for idx, original_output in enumerate(get_output_nodes(range_idx)):
            with original_gm.graph.inserting_after(prev_node):
                new_output_node = original_gm.graph.call_function(
                    operator.getitem, (submodule_node, idx)
                )
            node_map[original_output] = new_output_node
            prev_node = new_output_node

        # Replace all use of outputs
        for original_output in get_output_nodes(range_idx):
            original_output.replace_all_uses_with(node_map[original_output])

        # Erase old nodes
        # Reversed order so each node's users are already gone when it is
        # erased (FX refuses to erase a node that still has users).
        for node in reversed(get_body_nodes(range_idx)):
            original_gm.graph.erase_node(node)

    original_gm.recompile()

    return original_gm
104 | 131 |
|
105 | 132 |
|
def fold_range_to_submodule(
    original_gm: torch.fx.GraphModule,
    start_node_idx: int,
    end_node_idx: int,
    submodule_hook=None,
    submodule_name="extraced_submodule",
    group_head_and_tail=True,
):
    """Fold body nodes ``[start_node_idx, end_node_idx)`` into one submodule.

    Backward-compatible wrapper over ``convert_to_submodules_graph`` that
    expresses the single range as a pair of split positions.  The default
    name keeps the historical "extraced" spelling so existing callers keep
    seeing the same submodule attribute.
    """
    single_range_cuts = [start_node_idx, end_node_idx]
    return convert_to_submodules_graph(
        original_gm,
        split_positions=single_range_cuts,
        submodule_hook=submodule_hook,
        submodule_name_prefix=submodule_name,
        group_head_and_tail=group_head_and_tail,
    )
| 148 | + |
| 149 | + |
106 | 150 | @dataclass |
107 | 151 | class NodeProducedOrConsumedCountCtx: |
108 | 152 | node2before_input: defaultdict(int) |
|
0 commit comments