We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent f2045c4 commit cf0e73fCopy full SHA for cf0e73f
ai_edge_torch/debug/test/test_culprit.py
@@ -15,14 +15,14 @@
15
16
17
import ast
18
-import io
19
-import sys
20
21
-from ai_edge_torch.debug import find_culprits
+import ai_edge_torch.debug
22
import torch
23
24
from absl.testing import absltest as googletest
25
+find_culprits = ai_edge_torch.debug.find_culprits
+
26
_test_culprit_lib = torch.library.Library("test_culprit", "DEF")
27
28
_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ def forward(self, x):
52
53
class TestCulprit(googletest.TestCase):
54
55
+ def setUp(self):
56
+ super().setUp()
57
+ torch.manual_seed(0)
58
+ torch._dynamo.reset()
59
60
def test_find_culprits(self):
61
model = BadModel().eval()
62
args = (torch.rand(10),)
0 commit comments