Skip to content

Commit dff47e7

Browse files
Bordalexierule
authored andcommitted
Tests: fix deprecated TM mape (#8830)
(cherry picked from commit 3096ab8)
1 parent a4283c2 commit dff47e7

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

tests/callbacks/test_quantization.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,10 @@
1616

1717
import pytest
1818
import torch
19+
from torchmetrics.functional import mean_absolute_percentage_error as mape
1920

2021
from pytorch_lightning import seed_everything, Trainer
2122
from pytorch_lightning.callbacks import QuantizationAwareTraining
22-
from pytorch_lightning.metrics.functional.mean_relative_error import mean_relative_error
2323
from pytorch_lightning.utilities.exceptions import MisconfigurationException
2424
from tests.helpers.datamodules import RegressDataModule
2525
from tests.helpers.runif import RunIf
@@ -41,7 +41,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
4141
trainer = Trainer(**trainer_args)
4242
trainer.fit(model, datamodule=dm)
4343
org_size = model.model_size
44-
org_score = torch.mean(torch.tensor([mean_relative_error(model(x), y) for x, y in dm.test_dataloader()]))
44+
org_score = torch.mean(torch.tensor([mape(model(x), y) for x, y in dm.test_dataloader()]))
4545

4646
fusing_layers = [(f"layer_{i}", f"layer_{i}a") for i in range(3)] if fuse else None
4747
qcb = QuantizationAwareTraining(observer_type=observe, modules_to_fuse=fusing_layers, quantize_on_fit_end=convert)
@@ -50,7 +50,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
5050

5151
quant_calls = qcb._forward_calls
5252
assert quant_calls == qcb._forward_calls
53-
quant_score = torch.mean(torch.tensor([mean_relative_error(qmodel(x), y) for x, y in dm.test_dataloader()]))
53+
quant_score = torch.mean(torch.tensor([mape(qmodel(x), y) for x, y in dm.test_dataloader()]))
5454
# test that the test score is almost the same as with pure training
5555
assert torch.allclose(org_score, quant_score, atol=0.45)
5656
model_path = trainer.checkpoint_callback.best_model_path
@@ -69,7 +69,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
6969

7070
# todo: make it work also with strict loading
7171
qmodel2 = RegressionModel.load_from_checkpoint(model_path, strict=False)
72-
quant2_score = torch.mean(torch.tensor([mean_relative_error(qmodel2(x), y) for x, y in dm.test_dataloader()]))
72+
quant2_score = torch.mean(torch.tensor([mape(qmodel2(x), y) for x, y in dm.test_dataloader()]))
7373
assert torch.allclose(org_score, quant2_score, atol=0.45)
7474

7575

0 commit comments

Comments
 (0)