Skip to content

Commit cf0e73f

Browse files
chunnienccopybara-github
authored andcommitted
Fix culprit test
PiperOrigin-RevId: 704798083
1 parent f2045c4 commit cf0e73f

File tree

1 file changed

+8
-3
lines changed

1 file changed

+8
-3
lines changed

ai_edge_torch/debug/test/test_culprit.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,14 @@
1515

1616

1717
import ast
18-
import io
19-
import sys
2018

21-
from ai_edge_torch.debug import find_culprits
19+
import ai_edge_torch.debug
2220
import torch
2321

2422
from absl.testing import absltest as googletest
2523

24+
find_culprits = ai_edge_torch.debug.find_culprits
25+
2626
_test_culprit_lib = torch.library.Library("test_culprit", "DEF")
2727

2828
_test_culprit_lib.define("non_lowerable_op(Tensor x) -> Tensor")
@@ -52,6 +52,11 @@ def forward(self, x):
5252

5353
class TestCulprit(googletest.TestCase):
5454

55+
def setUp(self):
56+
super().setUp()
57+
torch.manual_seed(0)
58+
torch._dynamo.reset()
59+
5560
def test_find_culprits(self):
5661
model = BadModel().eval()
5762
args = (torch.rand(10),)

0 commit comments

Comments
 (0)