Skip to content
This repository was archived by the owner on Aug 21, 2025. It is now read-only.

Commit 0a8647c

Browse files
committed
Update coverage
1 parent aa2cd89 commit 0a8647c

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

test/discover_coverage.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -606,6 +606,9 @@ def has_opinfo(self):
606606
def __repr__(self):
607607
return f'Operator("{self.name}")'
608608

609+
def __hash__(self):
610+
return hash(self.name)
611+
609612
def no_opinfos_skip_test(self, test_name):
610613
"""Returns NO if any opinfos have a skip or xfail for the test"""
611614
if not self.has_opinfo():
@@ -680,6 +683,12 @@ def supports_jvp(self):
680683
def supports_jvpvjp(self):
681684
if self.name in FACTORY_FNS:
682685
return Support.YES
686+
exemptions = {
687+
# we have support (see OpInfo), testing artifact
688+
'torch.nn.functional.dropout2d',
689+
}
690+
if self.name in exemptions:
691+
return Support.YES
683692
return self.no_opinfos_skip_test('test_jvpvjp')
684693

685694
def _supports_vmapjvp_base(self, test):
@@ -785,15 +794,19 @@ def summary(self):
785794

786795
print("=" * 30 + " Top 125 Summary " + "=" * 30)
787796
opset = OperatorSet.from_top125()
788-
result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
789-
pprint.pprint(result)
797+
# result = opset.query(Operator.supports_jvp, (Support.NO, Support.UNKNOWN))
798+
# pprint.pprint(result)
790799
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
791800
pprint.pprint(result)
801+
# result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
802+
# pprint.pprint(result)
803+
# result = opset.query(Operator.supports_fast_vmapjvp, (Support.NO, Support.UNKNOWN))
804+
# pprint.pprint(result)
792805
# pprint.pprint(result)
793806
print(opset.summary())
794807

795808
print("=" * 30 + " Top 160 Summary " + "=" * 30)
796809
opset = OperatorSet.from_top160()
797-
result = opset.query(Operator.supports_vmapjvp, (Support.NO, Support.UNKNOWN))
798-
# pprint.pprint(result)
810+
result = opset.query(Operator.supports_jvpvjp, (Support.NO, Support.UNKNOWN))
811+
pprint.pprint(result)
799812
print(opset.summary())

0 commit comments

Comments
 (0)