Skip to content

Commit 7a04955

Browse files
author
Enrico Lupi
committed
ADD test for PQBatchNorm1d
1 parent 0805039 commit 7a04955

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed

test/pytest/test_pquant_pytorch.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33

44
import numpy as np
55
import pytest
6-
from pquant.activations import PQActivation
7-
from pquant.core.finetuning import TuningConfig
8-
from pquant.core.utils import get_default_config
9-
from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense
106

117
from hls4ml.converters import convert_from_pytorch_model
128
from hls4ml.utils import config_from_pytorch_model
139

1410
os.environ['KERAS_BACKEND'] = 'torch'
1511
import torch # noqa: E402
1612
import torch.nn as nn # noqa: E402
13+
from pquant.activations import PQActivation # noqa: E402
14+
from pquant.core.finetuning import TuningConfig # noqa: E402
15+
from pquant.core.utils import get_default_config # noqa: E402
16+
from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm1d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense # noqa: E402
1717

1818
test_path = Path(__file__).parent
1919

@@ -125,9 +125,9 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d
125125
case PQAvgPool2d():
126126
# (N, C, H, W)
127127
return (batch_size, 1, *default_hw)
128-
# case PQBatchNorm1d():
129-
# # (N, num_features, L)
130-
# return (batch_size, layer.num_features, *default_length)
128+
case PQBatchNorm1d():
129+
# (N, num_features, L)
130+
return (batch_size, layer.num_features, default_length)
131131
case PQBatchNorm2d():
132132
# (N, num_features, H, W)
133133
return (batch_size, layer.num_features, *default_hw)
@@ -159,6 +159,7 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d
159159
'PQConv2d(2, 3, kernel_size=(3,3), padding=0, bias=False)',
160160
'PQConv2d(2, 3, kernel_size=(3,3), padding=0, stride=2)',
161161
'PQConv2d(2, 3, kernel_size=(3,3), padding=1, stride=2)',
162+
'PQBatchNorm1d(3)',
162163
'PQBatchNorm2d(3)',
163164
'PQAvgPool1d(2, padding=1)',
164165
'PQAvgPool1d(2, padding=0)',

0 commit comments

Comments
 (0)