
Commit ab4a32f

[Partitioning] Recompute forward in the backward pass (#213)
Summary: Recomputing the fwd in the bwd pass can improve the performance of pointwise operators, where it helps us reduce memory bandwidth pressure at the expense of more computation. This PR adds a new partitioning function to enable this type of recomputation.
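For reference, a minimal usage sketch of the new partitioner, mirroring the test added in this commit (the identity "compilers" below simply return the traced fx module unchanged; a real backend would compile it):

import torch
from functorch import compiled_function, partition_with_recompute_fwd_in_bwd

def fn(a, b):
    return torch.sin(torch.sin(a)) + b

# Identity compiler: hand the traced fx.GraphModule back unchanged.
noop_compiler = lambda fx_module, _example_inputs: fx_module

recompute_fn = compiled_function(
    fn, noop_compiler, noop_compiler, partition_with_recompute_fwd_in_bwd
)

a = torch.rand(10, 10, requires_grad=True)
b = torch.rand(10, 10, requires_grad=True)
recompute_fn(a, b).sum().backward()
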
1 parent 157688a commit ab4a32f

File tree

functorch/__init__.py
functorch/_src/eager_compilation.py
test/test_pythonkey.py

3 files changed: +100 -2 lines changed

functorch/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -21,7 +21,7 @@
 from ._src.make_functional import functional_init, functional_init_with_buffers
 from ._src.python_key import wrap_key, PythonTensor, pythonkey_trace, make_fx, nnc_jit, make_nnc
 from ._src.nnc_compile import nnc_compile, get_ops
-from ._src.eager_compilation import compiled_function, compiled_module, tvm_compile, draw_joint_graph, default_partition
+from ._src.eager_compilation import compiled_function, compiled_module, tvm_compile, draw_joint_graph, default_partition, partition_with_recompute_fwd_in_bwd
 from ._src.operator_authoring import pointwise_operator

functorch/_src/eager_compilation.py

Lines changed: 75 additions & 0 deletions
@@ -82,6 +82,81 @@ def add_saved(a):
     bw_module.graph.lint()
     return fw_module, bw_module
 
+def partition_with_recompute_fwd_in_bwd(joint_module: fx.GraphModule, _joint_inputs):
+    """
+    Partitions the joint graph such that the backward recomputes the forward.
+    Recomputing helps in trading off memory bandwidth with computation.
+
+    To create the fwd and bwd graphs, we copy the joint graph, manually set the
+    outputs to just the original forward or backward outputs, and then run the
+    resulting graphs through dead code elimination.
+    """
+
+    def _extract_graph_with_given_outputs(joint_graph, outputs, is_fwd=False):
+        """
+        Returns a copy of joint_graph with the given outputs.
+
+        If it is the forward graph, we need extra bookkeeping:
+        1) Remove the tangent nodes from the inputs.
+        2) Pass the inputs directly to the output. These will be saved in the
+           backward ctx.
+        """
+        # Set up val_map to be used later for copying the graph
+        val_map = {}
+        saved_nodes = []
+        if is_fwd:
+            # Remove the tangent placeholder nodes from the graph by pre-mapping
+            # them, so graph_copy skips them
+            def _tangent_finder(node):
+                return node.op == "placeholder" and "tangents" in node.target
+            tangent_nodes = filter(_tangent_finder, joint_graph.nodes)
+            for tangent_node in tangent_nodes:
+                val_map[tangent_node] = 1
+
+            # Find the saved tensor nodes that will be used by ctx later
+            def _placeholder_finder(node):
+                return node.op == "placeholder" and "tangents" not in node.target
+            saved_nodes = list(filter(_placeholder_finder, joint_graph.nodes))
+
+        # Make a copy of the joint graph
+        graph = fx.Graph()
+        graph.graph_copy(joint_graph, val_map)
+
+        # Set the outputs
+        outputs = outputs + saved_nodes
+        if len(outputs) == 1:
+            graph.output(val_map[outputs[0]])
+        else:
+            graph.output([val_map[out] for out in outputs])
+
+        # Run dead code elimination to remove unnecessary nodes
+        graph.eliminate_dead_code()
+        graph.lint()
+        return graph
+
+    # Find the output node
+    output_node = None
+    for n in reversed(joint_module.graph.nodes):
+        if n.op == "output":
+            output_node = n
+            break
+
+    # Get the forward and backward output nodes
+    num_fwd_outputs = joint_module._out_spec.children_specs[0].num_leaves
+    fwd_outputs = output_node.args[0][0:num_fwd_outputs]
+    bwd_outputs = output_node.args[0][num_fwd_outputs:]
+
+    # Construct the forward module
+    fwd_graph = _extract_graph_with_given_outputs(
+        joint_module.graph, fwd_outputs, is_fwd=True
+    )
+    fwd_module = fx.GraphModule(joint_module, fwd_graph)
+
+    # Construct the backward module
+    bwd_graph = _extract_graph_with_given_outputs(joint_module.graph, bwd_outputs)
+    bwd_module = fx.GraphModule(joint_module, bwd_graph)
+
+    return fwd_module, bwd_module
+
 def create_joint_forward_backward(fn):
     def joint_forward_backward(primals, tangents):
         out = fn(*primals)
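
To make concrete what this partitioner aims to produce, here is a hand-written, illustrative equivalent of the forward/backward split for a simple pointwise chain fn(a, b) = sin(sin(a)) + b (a sketch of the intended behavior, not the actual traced fx graphs): the forward returns its output plus only the primal inputs as saved tensors, and the backward recomputes the sin intermediates instead of loading them from memory.

import torch

def fwd(a, b):
    # Forward output, plus the tensors saved for backward: just the primal inputs.
    out = torch.sin(torch.sin(a)) + b
    return out, (a, b)

def bwd(saved, grad_out):
    a, b = saved
    # Recompute the forward intermediate instead of having saved it:
    # d/da sin(sin(a)) = cos(sin(a)) * cos(a)
    sin_a = torch.sin(a)
    grad_a = grad_out * torch.cos(sin_a) * torch.cos(a)
    grad_b = grad_out  # d/db (sin(sin(a)) + b) = 1
    return grad_a, grad_b

A default partition would typically save the sin intermediates for the backward; recomputing them instead trades extra FLOPs for less memory traffic, which is the win described in the summary for bandwidth-bound pointwise operators.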

test/test_pythonkey.py

Lines changed: 24 additions & 1 deletion
@@ -23,7 +23,8 @@
 import functorch
 from functorch import (
     grad, vjp, vmap, jacrev, grad_and_value,
-    make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit, compiled_function, compiled_module
+    make_functional_deprecated_v1, make_functional_with_buffers_deprecated_v1, make_fx, nnc_jit, compiled_function, compiled_module,
+    partition_with_recompute_fwd_in_bwd
 )
 
 from torch.testing._internal.common_device_type import ops, onlyCPU
@@ -365,6 +366,28 @@ def create_new_arg(x):
         self.assertEqual(orig_grad, compiled_grad)
 
 
+class TestPartitioning(TestCase):
+    def test_recompute_partitioning(self):
+        def fn(a, b):
+            return torch.sin(torch.sin(a)) + b
+
+        # Reference calculation
+        ref_a = torch.rand(10, 10, requires_grad=True)
+        ref_b = torch.rand(10, 10, requires_grad=True)
+        ref = fn(ref_a, ref_b)
+        ref.sum().backward()
+
+        # Compiled function calculation
+        res_a = ref_a.clone().detach().requires_grad_(True)
+        res_b = ref_b.clone().detach().requires_grad_(True)
+        compile_fn = lambda x, _: x
+        compiled_fn = compiled_function(fn, compile_fn, compile_fn, partition_with_recompute_fwd_in_bwd)
+        res = compiled_fn(res_a, res_b)
+        res.sum().backward()
+        assert torch.allclose(ref, res, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(ref_a.grad, res_a.grad, atol=1e-3, rtol=1e-3)
+        assert torch.allclose(ref_b.grad, res_b.grad, atol=1e-3, rtol=1e-3)
+
 
 only_for = ("cpu")
 instantiate_device_type_tests(
