Commit e8a6e02

ananthsub authored and Borda committed
Avoid torchscript export for Metric forward (#4428)
* Update metric.py
* add test
* Update CHANGELOG.md
* Update test_metric_lightning.py
* Update test_metric_lightning.py

Co-authored-by: Jirka Borovec <[email protected]>
(cherry picked from commit 5d08559)
1 parent 38f4a83 commit e8a6e02
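The change itself is one line: `Metric.forward` gains a `@torch.jit.unused` decorator, so PyTorch's recursive scripting skips compiling the metric's dynamic `forward` (its `*args`/`**kwargs` signature alone is not scriptable) instead of failing whenever a module merely holds a Metric as a submodule. Below is a minimal sketch of the mechanism; the `Counter` and `Model` classes are hypothetical stand-ins for a real Metric and LightningModule, not Lightning code:

import torch
from torch import nn, Tensor


class Counter(nn.Module):
    """Hypothetical stand-in for a Lightning Metric."""

    def __init__(self):
        super().__init__()
        self.total = torch.zeros(())

    @torch.jit.unused  # scripting swaps this body for a stub instead of compiling it
    def forward(self, *args, **kwargs):
        # *args/**kwargs mirror Metric.forward and are not scriptable on their own
        self.total = self.total + args[0].sum()
        return self.total


class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.metric = Counter()        # held as a submodule, never called in forward
        self.layer = nn.Linear(32, 2)

    def forward(self, x: Tensor) -> Tensor:
        return self.layer(x)


# Without @torch.jit.unused on Counter.forward this call would raise a
# compilation error; with it, export succeeds.
scripted = torch.jit.script(Model())
print(scripted(torch.randn(10, 32)).shape)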

File tree

3 files changed: +43 -2 lines

CHANGELOG.md

Lines changed: 2 additions & 1 deletion

@@ -19,7 +19,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added timeout for `tpu_device_exists` to ensure process does not hang indefinitely ([#4340](https://github.com/PyTorchLightning/pytorch-lightning/pull/4340))
 
-- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
+- Added global step indexing to the checkpoint name for a better sub-epoch checkpointing experience ([#3807](https://github.com/PyTorchLightning/pytorch-lightning/pull/3807))
 
 ### Changed
 
@@ -47,6 +47,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed AMP unscale for `on_after_backward` ([#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
+- Fixed TorchScript export when module includes Metrics ([#4428](https://github.com/PyTorchLightning/pytorch-lightning/pull/4428))
 
 ## [1.0.4] - 2020-10-27
 

pytorch_lightning/metrics/metric.py

Lines changed: 1 addition & 0 deletions

@@ -145,6 +145,7 @@ def add_state(
         self._defaults[name] = deepcopy(default)
         self._reductions[name] = dist_reduce_fx
 
+    @torch.jit.unused
     def forward(self, *args, **kwargs):
         """
         Automatically calls ``update()``. Returns the metric value over inputs if ``compute_on_step`` is True.
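Worth noting (general `torch.jit.unused` semantics, not specific to this diff): the decorator does not make `forward` scriptable; it tells the compiler to replace the body with a stub that raises if it is ever invoked from the scripted module. That is safe for metrics because Lightning calls `Metric.forward` only from eager-mode training code, never inside the exported graph. A small sketch of the failure mode:

import torch
from torch import nn, Tensor


class M(nn.Module):
    @torch.jit.unused
    def forward(self, x: Tensor) -> Tensor:
        return x * 2


scripted = torch.jit.script(M())  # compiles: the unused body is never inspected
try:
    scripted(torch.ones(1))       # but the generated stub raises when called
except Exception as err:          # exact exception type varies across PyTorch versions
    print(f"unused method raised: {type(err).__name__}")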

tests/metrics/test_metric_lightning.py

Lines changed: 40 additions & 1 deletion

@@ -1,5 +1,6 @@
-import torch
+import os
 
+import torch
 from pytorch_lightning import Trainer
 from pytorch_lightning.metrics import Metric
 from tests.base.boring_model import BoringModel

@@ -78,3 +79,41 @@ def training_step(self, batch, batch_idx):
 
     logged = trainer.logged_metrics
     assert torch.allclose(torch.tensor(logged["sum"]), model.sum)
+
+
+def test_scriptable(tmpdir):
+    class TestModel(BoringModel):
+        def __init__(self):
+            super().__init__()
+            # the metric is not used in the module's `forward`
+            # so the module should be exportable to TorchScript
+            self.metric = SumMetric()
+            self.sum = 0.0
+
+        def training_step(self, batch, batch_idx):
+            x = batch
+            self.metric(x.sum())
+            self.sum += x.sum()
+            self.log("sum", self.metric, on_epoch=True, on_step=False)
+            return self.step(x)
+
+    model = TestModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=2,
+        limit_val_batches=2,
+        max_epochs=1,
+        log_every_n_steps=1,
+        weights_summary=None,
+        logger=False,
+        checkpoint_callback=False,
+    )
+    trainer.fit(model)
+    rand_input = torch.randn(10, 32)
+
+    script_model = model.to_torchscript()
+
+    # test that we can still do inference
+    output = model(rand_input)
+    script_output = script_model(rand_input)
+    assert torch.allclose(output, script_output)
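For reference: `LightningModule.to_torchscript()` as used here is, in this release, a thin wrapper around `torch.jit.script`, so the test exercises the same recursive-scripting path sketched above. `SumMetric` is presumably the small `Metric` subclass defined earlier in this test file, outside the hunks shown.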
