diff --git a/docs/source-pytorch/deploy/production_advanced_2.rst b/docs/source-pytorch/deploy/production_advanced_2.rst index 766f80f1094ff..ba0b69b46def5 100644 --- a/docs/source-pytorch/deploy/production_advanced_2.rst +++ b/docs/source-pytorch/deploy/production_advanced_2.rst @@ -7,15 +7,20 @@ Deploy models into production (advanced) ---- -********************************* -Compile your model to TorchScript -********************************* -`TorchScript `_ allows you to serialize your models in a way that it can be loaded in non-Python environments. -The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you -can save or directly use. +************************************ +Export your model with torch.export +************************************ + +`torch.export `_ is the recommended way to capture PyTorch models for +deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees, +making models suitable for inference optimization and cross-platform deployment. +You can export any ``LightningModule`` using the ``torch.export.export()`` API. .. testcode:: python + import torch + from torch.export import export + class SimpleModel(LightningModule): def __init__(self): super().__init__() @@ -25,25 +30,27 @@ can save or directly use. return torch.relu(self.l1(x.view(x.size(0), -1))) - # create the model + # create the model and example input model = SimpleModel() - script = model.to_torchscript() + example_input = torch.randn(1, 64) - # save for use in production environment - torch.jit.save(script, "model.pt") + # export the model + exported_program = export(model, (example_input,)) -It is recommended that you install the latest supported version of PyTorch to use this feature without limitations. + # save for use in production environment + torch.export.save(exported_program, "model.pt2") -Once you have the exported model, you can run it in PyTorch or C++ runtime: +It is recommended that you install the latest supported version of PyTorch to use this feature without +limitations. Once you have the exported model, you can load and run it: .. code-block:: python inp = torch.rand(1, 64) - scripted_module = torch.jit.load("model.pt") - output = scripted_module(inp) + loaded_program = torch.export.load("model.pt2") + output = loaded_program.module()(inp) -If you want to script a different method, you can decorate the method with :func:`torch.jit.export`: +For more complex models, you can also export specific methods by creating a wrapper: .. code-block:: python @@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func self.dropout = nn.Dropout() self.mc_iteration = mc_iteration - @torch.jit.export def predict_step(self, batch, batch_idx): # enable Monte Carlo Dropout self.dropout.train() @@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func model = LitMCdropoutModel(...) - script = model.to_torchscript(file_path="model.pt", method="script") + example_batch = torch.randn(32, 10) # example input + + # Export the predict_step method + exported_program = torch.export.export( + lambda batch, idx: model.predict_step(batch, idx), + (example_batch, 0) + ) + torch.export.save(exported_program, "mc_dropout_model.pt2")