1616
1717import pytest
1818import torch
19+ from torchmetrics .functional import mean_absolute_percentage_error as mape
1920
2021from pytorch_lightning import seed_everything , Trainer
2122from pytorch_lightning .callbacks import QuantizationAwareTraining
22- from pytorch_lightning .metrics .functional .mean_relative_error import mean_relative_error
2323from pytorch_lightning .utilities .exceptions import MisconfigurationException
2424from tests .helpers .datamodules import RegressDataModule
2525from tests .helpers .runif import RunIf
@@ -41,7 +41,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
4141 trainer = Trainer (** trainer_args )
4242 trainer .fit (model , datamodule = dm )
4343 org_size = model .model_size
44- org_score = torch .mean (torch .tensor ([mean_relative_error (model (x ), y ) for x , y in dm .test_dataloader ()]))
44+ org_score = torch .mean (torch .tensor ([mape (model (x ), y ) for x , y in dm .test_dataloader ()]))
4545
4646 fusing_layers = [(f"layer_{ i } " , f"layer_{ i } a" ) for i in range (3 )] if fuse else None
4747 qcb = QuantizationAwareTraining (observer_type = observe , modules_to_fuse = fusing_layers , quantize_on_fit_end = convert )
@@ -50,7 +50,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
5050
5151 quant_calls = qcb ._forward_calls
5252 assert quant_calls == qcb ._forward_calls
53- quant_score = torch .mean (torch .tensor ([mean_relative_error (qmodel (x ), y ) for x , y in dm .test_dataloader ()]))
53+ quant_score = torch .mean (torch .tensor ([mape (qmodel (x ), y ) for x , y in dm .test_dataloader ()]))
5454 # test that the test score is almost the same as with pure training
5555 assert torch .allclose (org_score , quant_score , atol = 0.45 )
5656 model_path = trainer .checkpoint_callback .best_model_path
@@ -69,7 +69,7 @@ def test_quantization(tmpdir, observe: str, fuse: bool, convert: bool):
6969
7070 # todo: make it work also with strict loading
7171 qmodel2 = RegressionModel .load_from_checkpoint (model_path , strict = False )
72- quant2_score = torch .mean (torch .tensor ([mean_relative_error (qmodel2 (x ), y ) for x , y in dm .test_dataloader ()]))
72+ quant2_score = torch .mean (torch .tensor ([mape (qmodel2 (x ), y ) for x , y in dm .test_dataloader ()]))
7373 assert torch .allclose (org_score , quant2_score , atol = 0.45 )
7474
7575
0 commit comments