@@ -82,6 +82,81 @@ def add_saved(a):
82
82
bw_module .graph .lint ()
83
83
return fw_module , bw_module
84
84
85
def partition_with_recompute_fwd_in_bwd(joint_module: fx.GraphModule, _joint_inputs):
    """
    Partition the joint graph such that the backward recomputes the forward.

    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graphs, we copy the joint graph, manually set
    the outputs to just the original forward or backward outputs, and then
    run the resulting graphs through dead code elimination.

    Args:
        joint_module: joint forward/backward graph module. Must carry an
            ``_out_spec`` pytree spec whose first child describes the forward
            outputs (NOTE(review): assumed set by the joint tracer — confirm).
        _joint_inputs: unused; kept for partitioner-interface compatibility.

    Returns:
        ``(fwd_module, bwd_module)`` pair of ``fx.GraphModule``s. The forward
        additionally returns its (non-tangent) inputs so they can be saved in
        the autograd ctx; the backward is a pruned copy of the whole joint
        graph, so forward computation is re-executed there.
    """

    def _extract_graph_with_given_outputs(joint_graph, outputs, is_fwd=False):
        """
        Return a copy of ``joint_graph`` with the given outputs.

        If it is the forward graph, extra bookkeeping is needed:
        1) Remove the tangent nodes from the inputs.
        2) Pass the remaining inputs directly to the output; they will be
           saved in the backward ctx.
        """
        # val_map maps original nodes to copies; graph_copy skips any node
        # that is pre-seeded here, so seeding is how we drop nodes.
        val_map = {}
        saved_nodes = []
        if is_fwd:
            # Seed the tangent placeholders with a dummy value so graph_copy
            # does not copy them into the forward graph.
            tangent_nodes = [
                node for node in joint_graph.nodes
                if node.op == "placeholder" and "tangents" in node.target
            ]
            for tangent_node in tangent_nodes:
                val_map[tangent_node] = 1

            # The non-tangent placeholders are forwarded to the output so the
            # backward ctx can save them.
            saved_nodes = [
                node for node in joint_graph.nodes
                if node.op == "placeholder" and "tangents" not in node.target
            ]

        # Make a copy of the joint graph (graph_copy copies everything except
        # the output node and the pre-seeded entries in val_map).
        graph = fx.Graph()
        graph.graph_copy(joint_graph, val_map)

        # Set the outputs. list(...) guards against args[0] being a tuple,
        # in which case tuple + list would raise a TypeError.
        outputs = list(outputs) + saved_nodes
        if len(outputs) == 1:
            graph.output(val_map[outputs[0]])
        else:
            graph.output([val_map[out] for out in outputs])

        # Drop everything not reachable from the new outputs.
        graph.eliminate_dead_code()
        graph.lint()
        return graph

    # Find the output node; fx keeps it at the end of the node list.
    output_node = next(
        (n for n in reversed(joint_module.graph.nodes) if n.op == "output"), None
    )
    if output_node is None:
        # Previously this fell through to an opaque AttributeError below.
        raise RuntimeError("joint graph has no output node")

    # Split the flat joint outputs into forward and backward outputs using
    # the recorded pytree output spec.
    num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves
    fwd_outputs = output_node.args[0][:num_fwd_outputs]
    bwd_outputs = output_node.args[0][num_fwd_outputs:]

    # Construct the forward module (tangents removed, inputs forwarded).
    fwd_graph = _extract_graph_with_given_outputs(
        joint_module.graph, fwd_outputs, is_fwd=True
    )
    fwd_module = fx.GraphModule(joint_module, fwd_graph)

    # Construct the backward module — a pruned full copy of the joint graph,
    # which is where the forward gets recomputed.
    bwd_graph = _extract_graph_with_given_outputs(joint_module.graph, bwd_outputs)
    bwd_module = fx.GraphModule(joint_module, bwd_graph)

    return fwd_module, bwd_module

85
160
def create_joint_forward_backward (fn ):
86
161
def joint_forward_backward (primals , tangents ):
87
162
out = fn (* primals )
0 commit comments