Skip to content

Commit fc5abd4

Browse files
committed
updated gen_data script
1 parent 08f1a65 commit fc5abd4

File tree

2 files changed

+116
-104
lines changed

2 files changed

+116
-104
lines changed

op_analysis/gen_data.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -79,25 +79,13 @@ def get_ops_for_key(key):
7979
batched_registrations = get_ops_for_key('FuncTorchBatched')
8080
all_ops = get_ops_for_key(None)
8181

82-
# Find all occurrences of things inside of STOP_DECOMPOSE(...) using regex
83-
# Look in ../functorch/csrc/BatchRulesStopDecomposition.cpp
84-
# Example:
85-
# STOP_DECOMPOSE(sin); => sin
86-
with open('../functorch/csrc/BatchRulesStopDecomposition.cpp') as f:
87-
content = f.read()
88-
stop_decomposition_regex = re.compile(r'STOP_DECOMPOSE\((.*)\);')
89-
stop_decomposition_matches = stop_decomposition_regex.findall(content)
90-
stop_decomposition_matches = [m.strip() for m in stop_decomposition_matches]
91-
stop_decomposition_ops = set(stop_decomposition_matches)
92-
9382
composite_ops = get_ops_for_key('CompositeImplicitAutograd')
94-
decomposed_ops = composite_ops - stop_decomposition_ops
9583

9684

97-
vmap_ops = (batched_registrations - stop_decomposition_ops) | (composite_ops - stop_decomposition_ops)
85+
vmap_ops = batched_registrations
9886
noncomposite_ops = all_ops - composite_ops
9987

100-
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read())
88+
ops = yaml.load(open('/home/chilli/fb/pytorch/aten/src/ATen/native/native_functions.yaml', 'r').read(), Loader=yaml.CLoader)
10189

10290
annotated_ops = {a.strip(): b.strip() for a,b in list(csv.reader(open('annotated_ops.txt')))}
10391
from collections import defaultdict
@@ -133,8 +121,6 @@ def annotate_ops(ops, is_unique):
133121
categorization['inplace'] += 1
134122
op['meta'] = 'inplace'
135123
continue
136-
if 'slow_conv3d_backward.grad_input' in op['full_name']:
137-
import pdb; pdb.set_trace()
138124
if not is_unique and 'a!' in op['func'].lower():
139125
categorization['out'] += 1
140126
op['meta'] = 'out'

0 commit comments

Comments
 (0)