Skip to content

Commit a8b3c1f

Browse files
author
Enrico Lupi
committed
ADD support for PQBatchNorm1d in pytorch
1 parent 0805039 commit a8b3c1f

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

hls4ml/converters/pytorch/pquant.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,8 @@ def handler(operation, layer_name, input_names, input_shapes, node, class_object
111111

112112

113113
parse_pqlinear_layer = make_pquant_handler(parse_linear_layer, 'PQDense', 'PQLinear')
114-
parse_pqbatchnorm_layer = make_pquant_handler(parse_batchnorm_layer, 'PQBatchNorm2d')
114+
parse_pqbatchnorm1d_layer = make_pquant_handler(parse_batchnorm_layer, 'PQBatchNorm1d')
115+
parse_pqbatchnorm2d_layer = make_pquant_handler(parse_batchnorm_layer, 'PQBatchNorm2d')
115116
parse_pqconv1d_layer = make_pquant_handler(parse_conv1d_layer, 'PQConv1d')
116117
parse_pqconv2d_layer = make_pquant_handler(parse_conv2d_layer, 'PQConv2d')
117118
parse_pqpool1d_layer = make_pquant_handler(parse_pooling_layer, 'PQAvgPool1d', 'AvgPool1d')

test/pytest/test_pquant_pytorch.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,17 @@
33

44
import numpy as np
55
import pytest
6-
from pquant.activations import PQActivation
7-
from pquant.core.finetuning import TuningConfig
8-
from pquant.core.utils import get_default_config
9-
from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense
106

117
from hls4ml.converters import convert_from_pytorch_model
128
from hls4ml.utils import config_from_pytorch_model
139

1410
os.environ['KERAS_BACKEND'] = 'torch'
1511
import torch # noqa: E402
1612
import torch.nn as nn # noqa: E402
13+
from pquant.activations import PQActivation # noqa: E402
14+
from pquant.core.finetuning import TuningConfig # noqa: E402
15+
from pquant.core.utils import get_default_config # noqa: E402
16+
from pquant.layers import PQAvgPool1d, PQAvgPool2d, PQBatchNorm1d, PQBatchNorm2d, PQConv1d, PQConv2d, PQDense # noqa: E402
1717

1818
test_path = Path(__file__).parent
1919

@@ -125,9 +125,9 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d
125125
case PQAvgPool2d():
126126
# (N, C, H, W)
127127
return (batch_size, 1, *default_hw)
128-
# case PQBatchNorm1d():
129-
# # (N, num_features, L)
130-
# return (batch_size, layer.num_features, *default_length)
128+
case PQBatchNorm1d():
129+
# (N, num_features, L)
130+
return (batch_size, layer.num_features, default_length)
131131
case PQBatchNorm2d():
132132
# (N, num_features, H, W)
133133
return (batch_size, layer.num_features, *default_hw)
@@ -159,6 +159,7 @@ def get_shape(model: nn.Module, batch_size: int = 1, default_length: int = 32, d
159159
'PQConv2d(2, 3, kernel_size=(3,3), padding=0, bias=False)',
160160
'PQConv2d(2, 3, kernel_size=(3,3), padding=0, stride=2)',
161161
'PQConv2d(2, 3, kernel_size=(3,3), padding=1, stride=2)',
162+
'PQBatchNorm1d(3)',
162163
'PQBatchNorm2d(3)',
163164
'PQAvgPool1d(2, padding=1)',
164165
'PQAvgPool1d(2, padding=0)',

0 commit comments

Comments
 (0)