Skip to content

Commit 787d524

Browse files
authored
Make functorch codegen more robust (#693)
Previously the functorch codegen would only work if you didn't have a PyTorch develop install in your environment. This PR changes it so that the functorch codegen works when you have a PyTorch develop install in the environment. The reason for this change is that the PyTorch develop install adds a `tools` module into the environment. It turns out we can just rely on the tools module and this makes our codegen more robust to changes to pytorch/pytorch codegen (when compared to what we were doing before, which was keeping a copy of the PyTorch codegen inside of the functorch repo). Test Plan: - wait for tests
1 parent 10d9874 commit 787d524

36 files changed

+966
-9048
lines changed

codegen/gen.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
import argparse
3+
import pathlib
4+
5+
from tools.codegen.gen import FileManager, parse_native_yaml
6+
from gen_vmap_plumbing import gen_all_vmap_plumbing
7+
8+
"""
9+
INSTRUCTIONS
10+
11+
Step 1: You must have a PyTorch installation (in develop mode, i.e.
12+
installed with python setup.py develop) in your current environment.
13+
This script relies on the `tools` module from the PyTorch develop installation.
14+
15+
Step 2: Run this script.
16+
17+
# Replace the last argument with your path to native_functions.yaml
18+
python codegen/gen.py -s /scratch/rzou/pt/debug-cpu/aten/src/ATen
19+
20+
NB: PyTorch's `tools` module is a giant hack (it somehow gets installed into your
21+
environment when one does python setup.py develop), but it's highly likely that
22+
PyTorch won't change it anytime soon because it's very messy to modify.
23+
"""
24+
25+
26+
def main() -> None:
27+
parser = argparse.ArgumentParser(description='functorch codegen')
28+
parser.add_argument(
29+
'-s',
30+
'--source-path',
31+
help='path to source directory for ATen',
32+
default='/scratch/rzou/pt/debug-cpu/aten/src/ATen')
33+
parser.add_argument(
34+
'-o',
35+
'--output-dependencies',
36+
help='output a list of dependencies into the given file and exit')
37+
parser.add_argument(
38+
'--dry-run', action='store_true',
39+
help='run without writing any files (still updates outputs)')
40+
parser.add_argument(
41+
'-d', '--install_dir', help='output directory',
42+
default='functorch/csrc')
43+
options = parser.parse_args()
44+
45+
native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
46+
parsed_yaml = parse_native_yaml(native_yaml_path)
47+
native_functions, _ = parsed_yaml.native_functions, parsed_yaml.backend_indices
48+
template_dir = os.path.join(options.source_path, "templates")
49+
50+
def make_file_manager(install_dir: str) -> FileManager:
51+
return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)
52+
53+
cpu_fm = make_file_manager(options.install_dir)
54+
cpu_fm.write('VmapGeneratedPlumbing.h', lambda: gen_all_vmap_plumbing(native_functions))
55+
56+
if options.output_dependencies:
57+
depfile_path = pathlib.Path(options.output_dependencies).resolve()
58+
depfile_name = depfile_path.name
59+
depfile_stem = depfile_path.stem
60+
61+
for fm, prefix in [
62+
(cpu_fm, ""),
63+
]:
64+
varname = prefix + depfile_stem
65+
path = depfile_path.parent / (prefix + depfile_name)
66+
fm.write_outputs(varname, str(path))
67+
68+
69+
if __name__ == '__main__':
70+
main()

tools/codegen/gen_vmap_plumbing.py renamed to codegen/gen_vmap_plumbing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22
DispatcherSignature,
33
)
44
from tools.codegen.model import (
5-
BaseTy, Variant, OptionalType, BaseType, ListType, NativeFunction, Type,
5+
BaseTy, OptionalType, BaseType, ListType, NativeFunction, Type,
66
Argument, Return, SchemaKind, Tag
77
)
88
from tools.codegen.api.translate import translate
99
from tools.codegen.context import method_with_native_function
1010
from tools.codegen.utils import mapMaybe
1111
from dataclasses import dataclass
12-
from typing import List, Dict, Optional, Tuple, Set, Any, Union, Sequence, TypeVar
12+
from typing import List, Optional, Tuple
1313
import textwrap
1414

1515

functorch/csrc/VmapGeneratedPlumbing.h

Lines changed: 882 additions & 76 deletions
Large diffs are not rendered by default.

test/discover_coverage.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ class Support(enum.Enum):
597597
}
598598

599599
JVP_EXEMPTIONS = {
600+
'nn.functional.dropout', # not actually problem, randomness testing artifact
600601
'nn.functional.dropout2d', # not actually problem, randomness testing artifact
601602
'nn.functional.rrelu', # not actually problem, randomness testing artifact
602603
# 'normal',
@@ -808,10 +809,10 @@ def summary(self):
808809

809810
print("=" * 30 + " Top 60 Summary " + "=" * 30)
810811
opset = OperatorSet.from_top_ops_threshold(35, 25)
811-
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
812-
pprint.pprint(result)
813-
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
814-
pprint.pprint(result)
812+
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
813+
# pprint.pprint(result)
814+
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
815+
# pprint.pprint(result)
815816
# kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
816817
# kpprint.pprint(result)
817818
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
@@ -823,12 +824,14 @@ def summary(self):
823824

824825
print("=" * 30 + " Top 125 Summary " + "=" * 30)
825826
opset = OperatorSet.from_top125()
826-
result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
827-
pprint.pprint(result)
828-
# kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
829-
# kpprint.pprint(result)
830-
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
827+
# result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
828+
# pprint.pprint(result)
829+
# result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
831830
# pprint.pprint(result)
831+
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
832+
pprint.pprint(result)
833+
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
834+
pprint.pprint(result)
832835
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
833836
# pprint.pprint(result)
834837
# pprint.pprint(result)

test/test_ops.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,8 +381,6 @@ def wrapped_fn(*args, **kwargs):
381381
xfail('std_mean'),
382382
# https://gist.github.com/zou3519/f62a167fb46cda01d7f238f61dd9ccf9
383383
xfail('linalg.eigvalsh'),
384-
# https://gist.github.com/zou3519/b86616d01ca375a4bd17403277f49225
385-
xfail('nn.functional.dropout', device_type='cuda'),
386384
387385
# =============================================
388386
# NB: The above failures also fail using PyTorch core's

tools/README.md

Lines changed: 0 additions & 11 deletions
This file was deleted.

tools/codegen/__init__.py

Whitespace-only changes.

tools/codegen/api/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)