|
3 | 3 |
|
4 | 4 | import numpy as np |
5 | 5 | import pytest |
6 | | -from pquant.activations import PQActivation |
7 | | -from pquant.core.finetuning import TuningConfig |
8 | | -from pquant.core.utils import get_default_config |
9 | | -from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense |
10 | 6 |
|
11 | 7 | from hls4ml.converters import convert_from_pytorch_model |
12 | 8 | from hls4ml.utils import config_from_pytorch_model |
13 | 9 |
|
14 | 10 | os.environ['KERAS_BACKEND'] = 'torch' |
15 | 11 | import torch # noqa: E402 |
16 | 12 | import torch.nn as nn # noqa: E402 |
| 13 | +from pquant.activations import PQActivation # noqa: E402 |
| 14 | +from pquant.core.finetuning import TuningConfig # noqa: E402 |
| 15 | +from pquant.core.utils import get_default_config # noqa: E402 |
| 16 | +from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm1d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense # noqa: E402 |
17 | 17 |
|
18 | 18 | test_path = Path(__file__).parent |
19 | 19 |
|
@@ -125,9 +125,9 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d |
125 | 125 | case PQAvgPool2d(): |
126 | 126 | # (N, C, H, W) |
127 | 127 | return (batch_size, 1, *default_hw) |
128 | | - # case PQBatchNorm1d(): |
129 | | - # # (N, num_features, L) |
130 | | - # return (batch_size, layer.num_features, *default_length) |
| 128 | + case PQBatchNorm1d(): |
| 129 | + # (N, num_features, L) |
| 130 | + return (batch_size, layer.num_features, default_length) |
131 | 131 | case PQBatchNorm2d(): |
132 | 132 | # (N, num_features, H, W) |
133 | 133 | return (batch_size, layer.num_features, *default_hw) |
@@ -159,6 +159,7 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d |
159 | 159 | 'PQConv2d(2, 3, kernel_size=(3,3), padding=0, bias=False)', |
160 | 160 | 'PQConv2d(2, 3, kernel_size=(3,3), padding=0, stride=2)', |
161 | 161 | 'PQConv2d(2, 3, kernel_size=(3,3), padding=1, stride=2)', |
| 162 | + 'PQBatchNorm1d(3)', |
162 | 163 | 'PQBatchNorm2d(3)', |
163 | 164 | 'PQAvgPool1d(2, padding=1)', |
164 | 165 | 'PQAvgPool1d(2, padding=0)', |
|
0 commit comments