1
1
import torch
2
2
import copy
3
3
from torch .testing ._internal .common_methods_invocations import op_db
4
+ from functorch_additional_op_db import additional_op_db
4
5
from enum import Enum
5
6
import functorch ._src .top_operators_github_usage as top_ops
6
7
import pprint
@@ -546,7 +547,7 @@ def print_coverage_info(th=100, nn=25):
546
547
547
548
def get_name_to_opinfo_map ():
548
549
dct = {}
549
- for op in op_db :
550
+ for op in ( op_db + additional_op_db ) :
550
551
def add (name , op ):
551
552
if name not in dct :
552
553
dct [name ] = []
@@ -571,6 +572,12 @@ class Support(enum.Enum):
571
572
'full' , 'randperm' , 'eye' , 'randint' , 'linspace' , 'logspace' ,
572
573
}
573
574
575
+ VJP_EXEMPTIONS = {
576
+ 'nn.functional.dropout' , # not actually problem, randomness testing artifact
577
+ 'nn.functional.dropout2d' , # not actually problem, randomness testing artifact
578
+ 'nn.functional.rrelu' , # not actually problem, randomness testing artifact
579
+ }
580
+
574
581
VMAP_EXEMPTIONS = {
575
582
'randn_like' , # randomness
576
583
'rand_like' , # randomness
@@ -583,10 +590,14 @@ class Support(enum.Enum):
583
590
'svd' , # There isn't a bug, it is just nondeterministic so we can't test it.
584
591
'nn.functional.embedding' , # We support everything except the sparse option.
585
592
'nn.functional.dropout' , # randomness
593
+ 'nn.functional.dropout2d' , # randomness
594
+ 'bernoulli' , # randomness
595
+ 'multinomial' , # randomness
596
+ 'normal' , # randomness
586
597
}
587
598
588
599
JVP_EXEMPTIONS = {
589
- 'nn.functional.dropout ' , # not actually problem, randomness testing artifact
600
+ 'nn.functional.dropout2d ' , # not actually problem, randomness testing artifact
590
601
'nn.functional.rrelu' , # not actually problem, randomness testing artifact
591
602
# 'normal',
592
603
# 'bernoulli',
@@ -613,8 +624,9 @@ def no_opinfos_skip_test(self, test_name):
613
624
"""Returns NO if any opinfos have a skip or xfail for the test"""
614
625
if not self .has_opinfo ():
615
626
return Support .UNKNOWN
616
- if not any ([in_functorch_lagging_op_db (o ) for o in self .opinfos ]):
617
- return Support .UNKNOWN
627
+ if not any ([o in additional_op_db for o in self .opinfos ]):
628
+ if not any ([in_functorch_lagging_op_db (o ) for o in self .opinfos ]):
629
+ return Support .UNKNOWN
618
630
for opinfo in self .opinfos :
619
631
for decorator in opinfo .decorators :
620
632
if not hasattr (decorator , 'test_name' ):
@@ -638,6 +650,8 @@ def all_opinfo_attr(self, attr):
638
650
def supports_vjp (self ):
639
651
if self .name in FACTORY_FNS :
640
652
return Support .YES
653
+ if self .name in VJP_EXEMPTIONS :
654
+ return Support .YES
641
655
return self .no_opinfos_skip_test ('test_vjp' )
642
656
643
657
def supports_vmap (self ):
@@ -696,8 +710,6 @@ def supports_jvpvjp(self):
696
710
def _supports_vmapjvp_base (self , test ):
697
711
if self .name in FACTORY_FNS :
698
712
return Support .YES
699
- if self .name in VMAP_EXEMPTIONS :
700
- return Support .YES
701
713
if self .name in JVP_EXEMPTIONS :
702
714
return Support .YES
703
715
if not self .has_opinfo ():
@@ -794,12 +806,27 @@ def summary(self):
794
806
result = opset .query (Operator .supports_vjp , (Support .NO , Support .UNKNOWN ))
795
807
# pprint.pprint(result)
796
808
809
+ print ("=" * 30 + " Top 60 Summary " + "=" * 30 )
810
+ opset = OperatorSet .from_top_ops_threshold (35 , 25 )
811
+ result = opset .query (Operator .supports_vmapjvp , (Support .NO , Support .UNKNOWN ))
812
+ pprint .pprint (result )
813
+ result = opset .query (Operator .supports_jvp , (Support .NO , Support .UNKNOWN ))
814
+ pprint .pprint (result )
815
+ #kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
816
+ #kpprint.pprint(result)
817
+ # result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
818
+ # pprint.pprint(result)
819
+ # result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
820
+ # pprint.pprint(result)
821
+ # pprint.pprint(result)
822
+ print (opset .summary ())
823
+
797
824
print ("=" * 30 + " Top 125 Summary " + "=" * 30 )
798
825
opset = OperatorSet .from_top125 ()
799
- # result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
800
- # pprint.pprint(result)
801
- result = opset .query (Operator .supports_jvpvjp , (Support .NO , Support .UNKNOWN ))
826
+ result = opset .query (Operator .supports_vmap , (Support .NO , Support .UNKNOWN ))
802
827
pprint .pprint (result )
828
+ #kresult = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
829
+ #kpprint.pprint(result)
803
830
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
804
831
# pprint.pprint(result)
805
832
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
0 commit comments