Skip to content

Commit 0ca049f

Browse files
committed
chore: use rn18 instead of rn50
Signed-off-by: Dheeraj Peri <[email protected]>
1 parent 3da78e9 commit 0ca049f

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

tests/cpp/test_modules_as_engines.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ INSTANTIATE_TEST_SUITE_P(
2424
ModuleAsEngineForwardIsCloseSuite,
2525
CppAPITests,
2626
testing::Values(
27-
PathAndInput({"tests/modules/resnet50_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
27+
PathAndInput({"tests/modules/resnet18_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
2828
PathAndInput({"tests/modules/mobilenet_v2_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
2929
PathAndInput({"tests/modules/efficientnet_b0_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99}),
3030
PathAndInput({"tests/modules/vit_scripted.jit.pt", {{1, 3, 224, 224}}, {at::kFloat}, 0.99})));

tests/py/models/test_models.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@
1010

1111

1212
class TestModels(unittest.TestCase):
13-
def test_resnet50(self):
14-
self.model = models.resnet50(pretrained=True).eval().to("cuda")
13+
def test_resnet18(self):
14+
self.model = models.resnet18(pretrained=True).eval().to("cuda")
1515
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
1616

1717
compile_spec = {
@@ -120,8 +120,8 @@ def test_bert_base_uncased(self):
120120
msg=f"HF BERT base-uncased TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
121121
)
122122

123-
def test_resnet50_half(self):
124-
self.model = models.resnet50(pretrained=True).eval().to("cuda")
123+
def test_resnet18_half(self):
124+
self.model = models.resnet18(pretrained=True).eval().to("cuda")
125125
self.input = torch.randn((1, 3, 224, 224)).to("cuda")
126126
self.scripted_model = torch.jit.script(self.model)
127127
self.scripted_model.half()

0 commit comments

Comments
 (0)