@@ -79,25 +79,13 @@ def get_ops_for_key(key):
79
79
batched_registrations = get_ops_for_key ('FuncTorchBatched' )
80
80
all_ops = get_ops_for_key (None )
81
81
82
- # Find all occurrences of things inside of STOP_DECOMPOSE(...) using regex
83
- # Look in ../functorch/csrc/BatchRulesStopDecomposition.cpp
84
- # Example:
85
- # STOP_DECOMPOSE(sin); => sin
86
- with open ('../functorch/csrc/BatchRulesStopDecomposition.cpp' ) as f :
87
- content = f .read ()
88
- stop_decomposition_regex = re .compile (r'STOP_DECOMPOSE\((.*)\);' )
89
- stop_decomposition_matches = stop_decomposition_regex .findall (content )
90
- stop_decomposition_matches = [m .strip () for m in stop_decomposition_matches ]
91
- stop_decomposition_ops = set (stop_decomposition_matches )
92
-
93
82
composite_ops = get_ops_for_key ('CompositeImplicitAutograd' )
94
- decomposed_ops = composite_ops - stop_decomposition_ops
95
83
96
84
97
- vmap_ops = ( batched_registrations - stop_decomposition_ops ) | ( composite_ops - stop_decomposition_ops )
85
+ vmap_ops = batched_registrations
98
86
noncomposite_ops = all_ops - composite_ops
99
87
100
- ops = yaml .load (open ('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml' , 'r' ).read ())
88
+ ops = yaml .load (open ('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml' , 'r' ).read (), Loader = yaml . CLoader )
101
89
102
90
annotated_ops = {a .strip (): b .strip () for a ,b in list (csv .reader (open ('annotated_ops.txt' )))}
103
91
from collections import defaultdict
@@ -133,8 +121,6 @@ def annotate_ops(ops, is_unique):
133
121
categorization ['inplace' ] += 1
134
122
op ['meta' ] = 'inplace'
135
123
continue
136
- if 'slow_conv3d_backward.grad_input' in op ['full_name' ]:
137
- import pdb ; pdb .set_trace ()
138
124
if not is_unique and 'a!' in op ['func' ].lower ():
139
125
categorization ['out' ] += 1
140
126
op ['meta' ] = 'out'
0 commit comments