@@ -606,6 +606,9 @@ def has_opinfo(self):
606
606
def __repr__ (self ):
607
607
return f'Operator("{ self .name } ")'
608
608
609
+ def __hash__ (self ):
610
+ return hash (self .name )
611
+
609
612
def no_opinfos_skip_test (self , test_name ):
610
613
"""Returns NO if any opinfos have a skip or xfail for the test"""
611
614
if not self .has_opinfo ():
@@ -680,6 +683,12 @@ def supports_jvp(self):
680
683
def supports_jvpvjp (self ):
681
684
if self .name in FACTORY_FNS :
682
685
return Support .YES
686
+ exemptions = {
687
+ # we have support (see OpInfo), testing artifact
688
+ 'torch.nn.functional.dropout2d' ,
689
+ }
690
+ if self .name in exemptions :
691
+ return Support .YES
683
692
return self .no_opinfos_skip_test ('test_jvpvjp' )
684
693
685
694
def _supports_vmapjvp_base (self , test ):
@@ -785,15 +794,19 @@ def summary(self):
785
794
786
795
print ("=" * 30 + " Top 125 Summary " + "=" * 30 )
787
796
opset = OperatorSet .from_top125 ()
788
- result = opset .query (Operator .supports_jvp , (Support .NO , Support .UNKNOWN ))
789
- pprint .pprint (result )
797
+ # result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
798
+ # pprint.pprint(result)
790
799
result = opset .query (Operator .supports_jvpvjp , (Support .NO , Support .UNKNOWN ))
791
800
pprint .pprint (result )
801
+ # result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
802
+ # pprint.pprint(result)
803
+ # result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
804
+ # pprint.pprint(result)
792
805
# pprint.pprint(result)
793
806
print (opset .summary ())
794
807
795
808
print ("=" * 30 + " Top 160 Summary " + "=" * 30 )
796
809
opset = OperatorSet .from_top160 ()
797
- result = opset .query (Operator .supports_vmapjvp , (Support .NO , Support .UNKNOWN ))
798
- # pprint.pprint(result)
810
+ result = opset .query (Operator .supports_jvpvjp , (Support .NO , Support .UNKNOWN ))
811
+ pprint .pprint (result )
799
812
print (opset .summary ())
0 commit comments