Skip to content

Commit 4aab57b

Browse files
committed
update discover_coverage
1 parent d35dca0 commit 4aab57b

File tree

1 file changed

+36
-9
lines changed

1 file changed

+36
-9
lines changed

test/discover_coverage.py

Lines changed: 36 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import torch
22
import copy
33
from torch.testing._internal.common_methods_invocations import op_db
4+
from functorch_additional_op_db import additional_op_db
45
from enum import Enum
56
import functorch._src.top_operators_github_usage as top_ops
67
import pprint
@@ -546,7 +547,7 @@ def print_coverage_info(th=100, nn=25):
546547

547548
def get_name_to_opinfo_map():
548549
dct = {}
549-
for op in op_db:
550+
for op in (op_db + additional_op_db):
550551
def add(name, op):
551552
if name not in dct:
552553
dct[name] = []
@@ -571,6 +572,12 @@ class Support(enum.Enum):
571572
'full', 'randperm', 'eye', 'randint', 'linspace', 'logspace',
572573
}
573574

575+
VJP_EXEMPTIONS = {
576+
'nn.functional.dropout', # not actually problem, randomness testing artifact
577+
'nn.functional.dropout2d', # not actually problem, randomness testing artifact
578+
'nn.functional.rrelu', # not actually problem, randomness testing artifact
579+
}
580+
574581
VMAP_EXEMPTIONS = {
575582
'randn_like', # randomness
576583
'rand_like', # randomness
@@ -583,10 +590,14 @@ class Support(enum.Enum):
583590
'svd', # There isn't a bug, it is just nondeterministic so we can't test it.
584591
'nn.functional.embedding', # We support everything except the sparse option.
585592
'nn.functional.dropout', # randomness
593+
'nn.functional.dropout2d', # randomness
594+
'bernoulli', # randomness
595+
'multinomial', # randomness
596+
'normal', # randomness
586597
}
587598

588599
JVP_EXEMPTIONS = {
589-
'nn.functional.dropout', # not actually problem, randomness testing artifact
600+
'nn.functional.dropout2d', # not actually problem, randomness testing artifact
590601
'nn.functional.rrelu', # not actually problem, randomness testing artifact
591602
# 'normal',
592603
# 'bernoulli',
@@ -613,8 +624,9 @@ def no_opinfos_skip_test(self, test_name):
613624
"""Returns NO if any opinfos have a skip or xfail for the test"""
614625
if not self.has_opinfo():
615626
return Support.UNKNOWN
616-
if not any([in_functorch_lagging_op_db(o) for o in self.opinfos]):
617-
return Support.UNKNOWN
627+
if not any([o in additional_op_db for o in self.opinfos]):
628+
if not any([in_functorch_lagging_op_db(o) for o in self.opinfos]):
629+
return Support.UNKNOWN
618630
for opinfo in self.opinfos:
619631
for decorator in opinfo.decorators:
620632
if not hasattr(decorator, 'test_name'):
@@ -638,6 +650,8 @@ def all_opinfo_attr(self, attr):
638650
def supports_vjp(self):
639651
if self.name in FACTORY_FNS:
640652
return Support.YES
653+
if self.name in VJP_EXEMPTIONS:
654+
return Support.YES
641655
return self.no_opinfos_skip_test('test_vjp')
642656

643657
def supports_vmap(self):
@@ -696,8 +710,6 @@ def supports_jvpvjp(self):
696710
def _supports_vmapjvp_base(self, test):
697711
if self.name in FACTORY_FNS:
698712
return Support.YES
699-
if self.name in VMAP_EXEMPTIONS:
700-
return Support.YES
701713
if self.name in JVP_EXEMPTIONS:
702714
return Support.YES
703715
if not self.has_opinfo():
@@ -794,12 +806,27 @@ def summary(self):
794806
result = opset.query(Operator.supports_vjp, (Support.NO, Support.UNKNOWN))
795807
# pprint.pprint(result)
796808

809+
print("=" * 30 + " Top 60 Summary " + "=" * 30)
810+
opset = OperatorSet.from_top_ops_threshold(35, 25)
811+
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
812+
pprint.pprint(result)
813+
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
814+
pprint.pprint(result)
815+
#kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
816+
#kpprint.pprint(result)
817+
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
818+
# pprint.pprint(result)
819+
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
820+
# pprint.pprint(result)
821+
# pprint.pprint(result)
822+
print(opset.summary())
823+
797824
print("=" * 30 + " Top 125 Summary " + "=" * 30)
798825
opset = OperatorSet.from_top125()
799-
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
800-
# pprint.pprint(result)
801-
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
826+
result = opset.query(Operator.supports_vmap, (Support.NO, Support.UNKNOWN))
802827
pprint.pprint(result)
828+
#kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
829+
#kpprint.pprint(result)
803830
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
804831
# pprint.pprint(result)
805832
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))

0 commit comments

Comments
 (0)