Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit d9de359

Browse files
authored
Update vmap codegen based on changes to torchgen (#898)
torchgen was changed sometime in the last month. This PR updates our codegen to work with the new changes. Coming soon: - we should be able to autogenerate our codegen and remove the checked in version.
1 parent 754ee26 commit d9de359

File tree

2 files changed

+19
-33
lines changed

2 files changed

+19
-33
lines changed

codegen/gen.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
import argparse
33
import pathlib
44

5+
import torchgen
56
from torchgen.gen import FileManager, parse_native_yaml
7+
from torchgen.gen import get_torchgen_root
68
from gen_vmap_plumbing import gen_all_vmap_plumbing
79

810
"""
@@ -14,12 +16,7 @@
1416
1517
Step 2: Run this script.
1618
17-
# Replace the last argument with your path to native_functions.yaml
18-
python codegen/gen.py -s /scratch/rzou/pt/debug-cpu/aten/src/ATen
19-
20-
NB: PyTorch's `tools` module is a giant hack (it somehow gets installed into your
21-
environment when one does python setup.py develop), but it's highly likely that
22-
PyTorch won't change it anytime soon because it's very messy to modify.
19+
python codegen/gen.py
2320
"""
2421

2522

@@ -29,42 +26,31 @@ def main() -> None:
2926
'-s',
3027
'--source-path',
3128
help='path to source directory for ATen',
32-
default='/scratch/rzou/pt/debug-cpu/aten/src/ATen')
33-
parser.add_argument(
34-
'-o',
35-
'--output-dependencies',
36-
help='output a list of dependencies into the given file and exit')
37-
parser.add_argument(
38-
'--dry-run', action='store_true',
39-
help='run without writing any files (still updates outputs)')
29+
default=None)
4030
parser.add_argument(
4131
'-d', '--install_dir', help='output directory',
4232
default='functorch/csrc')
4333
options = parser.parse_args()
34+
generate_code(options.install_dir, options.source_path)
35+
4436

45-
native_yaml_path = os.path.join(options.source_path, 'native/native_functions.yaml')
46-
parsed_yaml = parse_native_yaml(native_yaml_path)
37+
def generate_code(install_dir='functorch/csrc', source_path=None):
38+
if source_path is None:
39+
# infer the source path via torchgen
40+
source_path = os.path.join(get_torchgen_root(), "packaged/ATen")
41+
42+
native_yaml_path = os.path.join(source_path, 'native/native_functions.yaml')
43+
tags_path = os.path.join(source_path, 'native/tags.yaml')
44+
parsed_yaml = parse_native_yaml(native_yaml_path, tags_path)
4745
native_functions, _ = parsed_yaml.native_functions, parsed_yaml.backend_indices
48-
template_dir = os.path.join(options.source_path, "templates")
46+
template_dir = os.path.join(source_path, "templates")
4947

5048
def make_file_manager(install_dir: str) -> FileManager:
51-
return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=options.dry_run)
49+
return FileManager(install_dir=install_dir, template_dir=template_dir, dry_run=False)
5250

53-
cpu_fm = make_file_manager(options.install_dir)
51+
cpu_fm = make_file_manager(install_dir)
5452
cpu_fm.write('VmapGeneratedPlumbing.h', lambda: gen_all_vmap_plumbing(native_functions))
5553

56-
if options.output_dependencies:
57-
depfile_path = pathlib.Path(options.output_dependencies).resolve()
58-
depfile_name = depfile_path.name
59-
depfile_stem = depfile_path.stem
60-
61-
for fm, prefix in [
62-
(cpu_fm, ""),
63-
]:
64-
varname = prefix + depfile_stem
65-
path = depfile_path.parent / (prefix + depfile_name)
66-
fm.write_outputs(varname, str(path))
67-
6854

6955
if __name__ == '__main__':
7056
main()

codegen/gen_vmap_plumbing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
)
44
from torchgen.model import (
55
BaseTy, OptionalType, BaseType, ListType, NativeFunction, Type,
6-
Argument, Return, SchemaKind, Tag
6+
Argument, Return, SchemaKind
77
)
88
from torchgen.api.translate import translate
99
from torchgen.context import method_with_native_function
@@ -198,7 +198,7 @@ def gen_vmap_plumbing(native_function: NativeFunction) -> str:
198198
if not accepts_at_least_one_tensor_input(schema):
199199
return None
200200
# in-place views need special handling
201-
if native_function.tag == Tag.inplace_view:
201+
if 'inplace_view' in native_function.tags:
202202
return None
203203

204204
if schema.kind() == SchemaKind.inplace:

0 commit comments

Comments
 (0)