2
2
import argparse
3
3
import pathlib
4
4
5
+ import torchgen
5
6
from torchgen .gen import FileManager , parse_native_yaml
7
+ from torchgen .gen import get_torchgen_root
6
8
from gen_vmap_plumbing import gen_all_vmap_plumbing
7
9
8
10
"""
14
16
15
17
Step 2: Run this script.
16
18
17
- # Replace the last argument with your path to native_functions.yaml
18
- python codegen/gen.py -s /scratch/rzou/pt/debug-cpu/aten/src/ATen
19
-
20
- NB: PyTorch's `tools` module is a giant hack (it somehow gets installed into your
21
- environment when one does python setup.py develop), but it's highly likely that
22
- PyTorch won't change it anytime soon because it's very messy to modify.
19
+ python codegen/gen.py
23
20
"""
24
21
25
22
@@ -29,42 +26,31 @@ def main() -> None:
29
26
'-s' ,
30
27
'--source-path' ,
31
28
help = 'path to source directory for ATen' ,
32
- default = '/scratch/rzou/pt/debug-cpu/aten/src/ATen' )
33
- parser .add_argument (
34
- '-o' ,
35
- '--output-dependencies' ,
36
- help = 'output a list of dependencies into the given file and exit' )
37
- parser .add_argument (
38
- '--dry-run' , action = 'store_true' ,
39
- help = 'run without writing any files (still updates outputs)' )
29
+ default = None )
40
30
parser .add_argument (
41
31
'-d' , '--install_dir' , help = 'output directory' ,
42
32
default = 'functorch/csrc' )
43
33
options = parser .parse_args ()
34
+ generate_code (options .install_dir , options .source_path )
35
+
44
36
45
- native_yaml_path = os .path .join (options .source_path , 'native/native_functions.yaml' )
46
- parsed_yaml = parse_native_yaml (native_yaml_path )
37
+ def generate_code (install_dir = 'functorch/csrc' , source_path = None ):
38
+ if source_path is None :
39
+ # infer the source path via torchgen
40
+ source_path = os .path .join (get_torchgen_root (), "packaged/ATen" )
41
+
42
+ native_yaml_path = os .path .join (source_path , 'native/native_functions.yaml' )
43
+ tags_path = os .path .join (source_path , 'native/tags.yaml' )
44
+ parsed_yaml = parse_native_yaml (native_yaml_path , tags_path )
47
45
native_functions , _ = parsed_yaml .native_functions , parsed_yaml .backend_indices
48
- template_dir = os .path .join (options . source_path , "templates" )
46
+ template_dir = os .path .join (source_path , "templates" )
49
47
50
48
def make_file_manager (install_dir : str ) -> FileManager :
51
- return FileManager (install_dir = install_dir , template_dir = template_dir , dry_run = options . dry_run )
49
+ return FileManager (install_dir = install_dir , template_dir = template_dir , dry_run = False )
52
50
53
- cpu_fm = make_file_manager (options . install_dir )
51
+ cpu_fm = make_file_manager (install_dir )
54
52
cpu_fm .write ('VmapGeneratedPlumbing.h' , lambda : gen_all_vmap_plumbing (native_functions ))
55
53
56
- if options .output_dependencies :
57
- depfile_path = pathlib .Path (options .output_dependencies ).resolve ()
58
- depfile_name = depfile_path .name
59
- depfile_stem = depfile_path .stem
60
-
61
- for fm , prefix in [
62
- (cpu_fm , "" ),
63
- ]:
64
- varname = prefix + depfile_stem
65
- path = depfile_path .parent / (prefix + depfile_name )
66
- fm .write_outputs (varname , str (path ))
67
-
68
54
69
55
if __name__ == '__main__' :
70
56
main ()
0 commit comments