Skip to content

Commit 7617874

Browse files
authored
Use Functionalization pass (#810)
1 parent 88a4619 commit 7617874

File tree

4 files changed

+72
-0
lines changed

4 files changed

+72
-0
lines changed

functorch/_src/aot_autograd.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import torch.utils.dlpack
99
from torch.nn.utils import _stateless
1010
from functorch._C import CompileCache
11+
from functorch.experimental import functionalize
12+
from . import config
1113
from .decompositions import register_decomposition
1214
from .partitioners import default_partition
1315
from .named_members_polyfill import _named_parameters, _named_buffers
@@ -188,6 +190,12 @@ def forward(ctx, *flat_tensor_args):
188190
*joint_inputs
189191
)
190192

193+
if config.use_functionalize:
194+
# Functionalize the forward-backward graph. First create a
195+
# fake fn to make functionalize happy
196+
def fake_fn(primals, tangents):
197+
return fx_g(primals, tangents)
198+
fx_g = make_fx(functionalize(fake_fn))(*joint_inputs)
191199
fw_module, bw_module = partition_fn(fx_g, joint_inputs)
192200
# print(fw_module.code, bw_module.code)
193201

functorch/_src/config.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
"""
8+
Global flags for aot autograd
9+
"""
10+
11+
use_functionalize = False

functorch/compile/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,4 @@
2727
draw_graph,
2828
draw_joint_graph,
2929
)
30+
from .._src import config

test/test_functionalize.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import torch
2+
3+
import functorch
4+
from torch.testing._internal.common_utils import run_tests, TestCase, IS_WINDOWS
5+
import unittest
6+
from unittest.mock import patch
7+
import functools
8+
9+
from functorch.compile import aot_function, nop
10+
import test_compile_cache
11+
import test_pythonkey
12+
13+
14+
def make_functionalize_fn(fn):
    """Return a wrapper around *fn* that runs it with functionalization on.

    The wrapper temporarily patches
    ``functorch.compile.config.use_functionalize`` to ``True`` for the
    duration of the call; the flag is restored when the call returns.
    """

    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        with patch.object(functorch.compile.config, "use_functionalize", True):
            return fn(*args, **kwargs)

    return wrapped
21+
22+
23+
def make_functionalize_test(cls):
    """Derive a test class from *cls* whose tests run under functionalization.

    Each callable ``test_*`` attribute of *cls* is re-registered on the
    subclass under a ``*_functionalize`` name, wrapped so the functionalize
    config flag is enabled while it runs.  The inherited attribute of the
    original name is shadowed with ``None`` so unittest's loader (which
    skips non-callable attributes) does not collect the test twice.
    """

    class FunctionalizeTest(cls):
        pass

    FunctionalizeTest.__name__ = f"Functionalize{cls.__name__}"

    for attr_name in dir(cls):
        if not attr_name.startswith("test_"):
            continue
        original = getattr(cls, attr_name)
        if not callable(original):
            continue

        wrapped = make_functionalize_fn(original)
        wrapped.__name__ = f"{attr_name}_functionalize"
        setattr(FunctionalizeTest, attr_name, None)
        setattr(FunctionalizeTest, wrapped.__name__, wrapped)

    return FunctionalizeTest
42+
43+
44+
# Generate functionalization-enabled variants of the existing compile-cache
# and pythonkey test suites.  Each generated class re-runs the original
# suite's tests with config.use_functionalize patched to True.
FunctionalizeTestCompileCache = make_functionalize_test(test_compile_cache.TestCompileCache)
FunctionalizeTestCompileCacheStaticArgs = make_functionalize_test(test_compile_cache.TestCompileCacheStaticArgs)
FunctionalizeTestPythonKeyAOT = make_functionalize_test(test_pythonkey.TestAOTAutograd)
FunctionalizeTestPythonKeyContiguous = make_functionalize_test(test_pythonkey.TestContiguous)
FunctionalizeTestPythonKeyRandom = make_functionalize_test(test_pythonkey.TestRandom)
FunctionalizeTestPythonKeyPartitioning = make_functionalize_test(test_pythonkey.TestPartitioning)

if __name__ == "__main__":
    run_tests()

0 commit comments

Comments
 (0)