diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index acae618f1b..1ea62495f1 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -4,7 +4,7 @@ import logging import platform from enum import Enum -from typing import Any, Callable, List, Optional, Sequence, Set, Union +from typing import Any, Callable, List, Optional, Sequence, Set, Union, Literal import torch import torch.fx @@ -580,7 +580,9 @@ def save( module: Any, file_path: str = "", *, - output_format: str = "exported_program", + output_format: Literal[ + "exported_program", "torchscript", "aot_inductor" + ] = "exported_program", inputs: Optional[Sequence[torch.Tensor]] = None, arg_inputs: Optional[Sequence[torch.Tensor]] = None, kwarg_inputs: Optional[dict[str, Any]] = None, @@ -639,7 +641,7 @@ def save( "Input model is of type nn.Module. Saving nn.Module directly is not supported. Supported model types torch.jit.ScriptModule | torch.fx.GraphModule | torch.export.ExportedProgram." ) elif module_type == _ModuleType.ts: - if not all([output_format == f for f in ["exported_program", "aot_inductor"]]): + if output_format != "torchscript": raise ValueError( "Provided model is a torch.jit.ScriptModule but the output_format specified is not torchscript. Other output formats are not supported" ) diff --git a/tests/py/ts/api/test_export_serde.py b/tests/py/ts/api/test_export_serde.py new file mode 100644 index 0000000000..21bf57082b --- /dev/null +++ b/tests/py/ts/api/test_export_serde.py @@ -0,0 +1,58 @@ +import importlib +import os +import platform +import tempfile +import unittest + +import pytest +import torch +import torch_tensorrt as torchtrt +from torch_tensorrt.dynamo.utils import ( + COSINE_THRESHOLD, + cosine_similarity, + get_model_device, +) + +assertions = unittest.TestCase() + +@pytest.mark.unit +def test_save_load_ts(ir): + """ + This tests save/load API on Torchscript format (model still compiled using ts workflow) + """ + + class MyModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True) + self.relu = torch.nn.ReLU() + + def forward(self, x): + conv = self.conv(x) + relu = self.relu(conv) + mul = relu * 0.5 + return mul + + model = MyModule().eval().cuda() + input = torch.randn((1, 3, 224, 224)).to("cuda") + + trt_gm = torchtrt.compile( + model, + ir="ts", + inputs=[input], + min_block_size=1, + cache_built_engines=False, + reuse_cached_engines=False, + ) + outputs_trt = trt_gm(input) + # Save it as torchscript representation + torchtrt.save(trt_gm, "./trt.ts", output_format="torchscript", inputs=[input]) + + trt_ts_module = torchtrt.load("./trt.ts") + outputs_trt_deser = trt_ts_module(input) + + cos_sim = cosine_similarity(outputs_trt, outputs_trt_deser) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_save_load_ts TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + )