Skip to content

Commit 9bb2684

Browse files
author
liangtao07
committed
Add test for module type.
Signed-off-by: liangtao07 <[email protected]>
1 parent 5335fab commit 9bb2684

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

tests/py/test_api.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -485,6 +485,19 @@ def test_dynamic_shape(self):
485485
i = torchtrt.Input(min_shape=tensor_shape(min_shape), opt_shape=tensor_shape(opt_shape), max_shape=tensor_shape(max_shape))
486486
self.assertTrue(self._verify_correctness(i, target))
487487

488+
489+
class TestModule(unittest.TestCase):
490+
491+
def test_module_type(self):
492+
nn_module = models.alexnet(pretrained=True).eval().to("cuda")
493+
ts_module = torch.jit.trace(nn_module, torch.ones([1, 3, 224, 224]).to("cuda"))
494+
fx_module = torch.fx.symbolic_trace(nn_module)
495+
496+
self.assertEqual(torchtrt._compile._parse_module_type(nn_module), torchtrt._compile._ModuleType.nn)
497+
self.assertEqual(torchtrt._compile._parse_module_type(ts_module), torchtrt._compile._ModuleType.ts)
498+
self.assertEqual(torchtrt._compile._parse_module_type(fx_module), torchtrt._compile._ModuleType.fx)
499+
500+
488501
def test_suite():
489502
suite = unittest.TestSuite()
490503
suite.addTest(unittest.makeSuite(TestLoggingAPIs))
@@ -505,6 +518,7 @@ def test_suite():
505518
suite.addTest(unittest.makeSuite(TestCheckMethodOpSupport))
506519
suite.addTest(unittest.makeSuite(TestDevice))
507520
suite.addTest(unittest.makeSuite(TestInput))
521+
suite.addTest(unittest.makeSuite(TestModule))
508522

509523
return suite
510524

0 commit comments

Comments
 (0)