Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 30 additions & 17 deletions docs/source-pytorch/deploy/production_advanced_2.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,20 @@ Deploy models into production (advanced)

----

*********************************
Compile your model to TorchScript
*********************************
`TorchScript <https://pytorch.org/docs/stable/jit.html>`_ allows you to serialize your models in a way that it can be loaded in non-Python environments.
The ``LightningModule`` has a handy method :meth:`~lightning.pytorch.core.LightningModule.to_torchscript` that returns a scripted module which you
can save or directly use.
************************************
Export your model with torch.export
************************************

`torch.export <https://pytorch.org/docs/stable/export.html>`_ is the recommended way to capture PyTorch models for
deployment in production environments. It produces a clean intermediate representation with strong soundness guarantees,
making models suitable for inference optimization and cross-platform deployment.
You can export any ``LightningModule`` using the ``torch.export.export()`` API.

.. testcode:: python

import torch
from torch.export import export

class SimpleModel(LightningModule):
def __init__(self):
super().__init__()
Expand All @@ -25,25 +30,27 @@ can save or directly use.
return torch.relu(self.l1(x.view(x.size(0), -1)))


# create the model
# create the model and example input
model = SimpleModel()
script = model.to_torchscript()
example_input = torch.randn(1, 64)

# save for use in production environment
torch.jit.save(script, "model.pt")
# export the model
exported_program = export(model, (example_input,))

It is recommended that you install the latest supported version of PyTorch to use this feature without limitations.
# save for use in production environment
torch.export.save(exported_program, "model.pt2")

Once you have the exported model, you can run it in PyTorch or C++ runtime:
It is recommended that you install the latest supported version of PyTorch to use this feature without
limitations. Once you have the exported model, you can load and run it:

.. code-block:: python

inp = torch.rand(1, 64)
scripted_module = torch.jit.load("model.pt")
output = scripted_module(inp)
loaded_program = torch.export.load("model.pt2")
output = loaded_program.module()(inp)


If you want to script a different method, you can decorate the method with :func:`torch.jit.export`:
For more complex models, you can also export specific methods by creating a wrapper:

.. code-block:: python

Expand All @@ -54,7 +61,6 @@ If you want to script a different method, you can decorate the method with :func
self.dropout = nn.Dropout()
self.mc_iteration = mc_iteration

@torch.jit.export
def predict_step(self, batch, batch_idx):
# enable Monte Carlo Dropout
self.dropout.train()
Expand All @@ -66,4 +72,11 @@ If you want to script a different method, you can decorate the method with :func


model = LitMCdropoutModel(...)
script = model.to_torchscript(file_path="model.pt", method="script")
example_batch = torch.randn(32, 10) # example input

# Export the predict_step method
exported_program = torch.export.export(
lambda batch, idx: model.predict_step(batch, idx),
(example_batch, 0)
)
torch.export.save(exported_program, "mc_dropout_model.pt2")