diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 85f631ee40f75..78d805ef5074f 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -694,18 +694,17 @@ def all_gather(
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
 
     @override
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
+    def forward(self, *inputs) -> Any:
         r"""Same as :meth:`torch.nn.Module.forward`.
 
         Args:
             *args: Whatever you decide to pass into the forward method.
             **kwargs: Keyword arguments are also possible.
-
         Return:
             Your model's output
 
         """
-        return super().forward(*args, **kwargs)
+        return super().forward(*inputs)
 
     def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index 29f251044c0b5..781656d61b7ae 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -27,6 +27,14 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+def test_torchscript_vanilla():
+    """Test that LightningModule itself can be converted."""
+    model = LightningModule()
+
+    script = model.to_torchscript()
+    assert isinstance(script, torch.jit.ScriptModule)
+
+
 @pytest.mark.skipif(_IS_WINDOWS and _TORCH_GREATER_EQUAL_2_4, reason="not close on Windows + PyTorch 2.4")
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_input_output(modelclass):
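
The new `test_torchscript_vanilla` covers the smallest possible case: converting a bare `LightningModule` with no user-defined `forward`. A minimal sketch of the same flow outside the test suite, assuming a `lightning` install where `to_torchscript()` keeps its default `method="script"`; the `file_path` value is a hypothetical example:

```python
import torch
from lightning.pytorch import LightningModule

# A bare LightningModule with no overridden forward; the added test asserts
# that this now converts cleanly with the *inputs-only base signature.
model = LightningModule()

# to_torchscript() with its default method="script" hands the module to
# torch.jit.script and returns the resulting ScriptModule.
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)

# Optionally persist the scripted module (hypothetical output path).
model.to_torchscript(file_path="vanilla_module.pt")
```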
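
For downstream code, the base-class signature only matters when `forward` is not overridden; a subclass that declares its own arguments behaves exactly as before. A short sketch with a hypothetical `LitMLP` subclass (not part of this PR) to illustrate:

```python
import torch
from lightning.pytorch import LightningModule


class LitMLP(LightningModule):
    """Hypothetical subclass showing that explicit forward signatures are unaffected."""

    def __init__(self) -> None:
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Subclasses keep declaring whatever arguments they need; the
        # *inputs-only change applies to the un-overridden base method.
        return self.layer(x)


script = LitMLP().to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
```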