From 8c6698e57b32eedcfc4e9d1cbe2e4c8d95991c5c Mon Sep 17 00:00:00 2001
From: Andreas Kirsch
Date: Tue, 6 Feb 2024 16:09:30 +0000
Subject: [PATCH 1/4] Delete stub forward() method in Module

The current stub breaks TorchScript/`to_torchscript` with:

```
NotSupportedError: Compiled functions can't take variable number of arguments or use keyword-only arguments with defaults:
  File "/opt/conda/envs/image-gen/lib/python3.11/site-packages/pytorch_lightning/core/module.py", line 657
    def forward(self, *args: Any, **kwargs: Any) -> Any:
                                  ~~~~~~~ <--- HERE
        r"""Same as :meth:`torch.nn.Module.forward`.
```

This becomes an issue when you use a module as a container for other modules with bespoke methods and do not rely on the default forward call.
---
 src/lightning/pytorch/core/module.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index faaab7e15fd05..44092bc77b1cf 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -676,20 +676,6 @@ def all_gather(
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
 
-    @override
-    def forward(self, *args: Any, **kwargs: Any) -> Any:
-        r"""Same as :meth:`torch.nn.Module.forward`.
-
-        Args:
-            *args: Whatever you decide to pass into the forward method.
-            **kwargs: Keyword arguments are also possible.
-
-        Return:
-            Your model's output
-
-        """
-        return super().forward(*args, **kwargs)
-
     def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or
         logger.

From 35a8b49ccb65db5d91b2fc666eb7f2da4fd9d10e Mon Sep 17 00:00:00 2001
From: Andreas Kirsch
Date: Tue, 6 Feb 2024 16:12:43 +0000
Subject: [PATCH 2/4] Add a test against regressions.

---
 tests/tests_pytorch/models/test_torchscript.py | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/tests_pytorch/models/test_torchscript.py b/tests/tests_pytorch/models/test_torchscript.py
index 806b4db998af1..e030b0d67cff2 100644
--- a/tests/tests_pytorch/models/test_torchscript.py
+++ b/tests/tests_pytorch/models/test_torchscript.py
@@ -26,6 +26,14 @@
 from tests_pytorch.helpers.runif import RunIf
 
 
+def test_torchscript_vanilla():
+    """Test that LightningModule itself can be converted."""
+    model = LightningModule()
+
+    script = model.to_torchscript()
+    assert isinstance(script, torch.jit.ScriptModule)
+
+
 @pytest.mark.parametrize("modelclass", [BoringModel, ParityModuleRNN, BasicGAN])
 def test_torchscript_input_output(modelclass):
     """Test that scripted LightningModule forward works."""

From 207cdb10e359e41519be6a9a6cc6c7a995977a0d Mon Sep 17 00:00:00 2001
From: Andreas Kirsch
Date: Sun, 11 Feb 2024 10:33:39 +0000
Subject: [PATCH 3/4] Only fix signature

---
 src/lightning/pytorch/core/module.py | 11 +++++++++++
 1 file changed, 11 insertions(+)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 44092bc77b1cf..93a9657a34459 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -675,6 +675,17 @@ def all_gather(
         all_gather = self.trainer.strategy.all_gather
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
+
+    @override
+    def forward(self, *inputs) -> Any:
+        r"""Same as :meth:`torch.nn.Module.forward`.
+        Args:
+            *args: Whatever you decide to pass into the forward method.
+            **kwargs: Keyword arguments are also possible.
+        Return:
+            Your model's output
+        """
+        return super().forward(*inputs)
 
     def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
         r"""Here you compute and return the training loss and some additional metrics for e.g. the progress bar or

From b1d75d01e98203fdb1e363b7885e0e4a2cf9ed3d Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Sun, 11 Feb 2024 10:34:34 +0000
Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/lightning/pytorch/core/module.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/src/lightning/pytorch/core/module.py b/src/lightning/pytorch/core/module.py
index 93a9657a34459..05af488b1229f 100644
--- a/src/lightning/pytorch/core/module.py
+++ b/src/lightning/pytorch/core/module.py
@@ -675,15 +675,17 @@ def all_gather(
         all_gather = self.trainer.strategy.all_gather
         data = convert_to_tensors(data, device=self.device)
         return apply_to_collection(data, Tensor, all_gather, group=group, sync_grads=sync_grads)
-
+
     @override
     def forward(self, *inputs) -> Any:
         r"""Same as :meth:`torch.nn.Module.forward`.
+
         Args:
             *args: Whatever you decide to pass into the forward method.
             **kwargs: Keyword arguments are also possible.
         Return:
             Your model's output
+
         """
         return super().forward(*inputs)
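
Note for reviewers (not part of the patches): a minimal standalone sketch of what the new `test_torchscript_vanilla` test exercises. It assumes an environment with `torch` and `lightning` installed; the behaviour described in the comments is the one this patch series targets, restated from the commit message and test above rather than verified independently.

```python
# Standalone sketch of the regression covered by test_torchscript_vanilla.
import torch
from lightning.pytorch import LightningModule

model = LightningModule()

# With the inherited stub `forward(self, *args: Any, **kwargs: Any)`,
# torch.jit.script() inside to_torchscript() rejected the module with
# "NotSupportedError: Compiled functions can't take variable number of
# arguments or use keyword-only arguments with defaults".
# After this patch series, the same call is expected to succeed, which is
# what matters when the module is only a container for other modules and
# the default forward is never used.
script = model.to_torchscript()
assert isinstance(script, torch.jit.ScriptModule)
```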