Skip to content

Commit c7e3f2a

Browse files
authored
Fix (nn/bias): propagate runtime_shape from QuantScaleBias (#1385)
1 parent 54db11f commit c7e3f2a

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

src/brevitas/nn/quant_scale_bias.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# SPDX-License-Identifier: BSD-3-Clause
33

44
from typing import Optional
5+
from typing import Tuple
56
from typing import Type
67
from typing import Union
78

@@ -25,7 +26,8 @@
2526

2627
class ScaleBias(Module):
2728

28-
def __init__(self, num_features: int, bias: bool, runtime_shape=(1, -1, 1, 1)):
29+
def __init__(
30+
self, num_features: int, bias: bool, runtime_shape: Tuple[int, ...] = (1, -1, 1, 1)):
2931
super(ScaleBias, self).__init__()
3032
self.num_features = num_features
3133
self.weight = Parameter(torch.ones(num_features))
@@ -49,9 +51,10 @@ def __init__(
4951
bias_quant: Optional[BiasQuantType] = None,
5052
input_quant: Optional[ActQuantType] = None,
5153
output_quant: Optional[ActQuantType] = None,
54+
runtime_shape: Tuple[int, ...] = (1, -1, 1, 1),
5255
return_quant_tensor: bool = False,
5356
**kwargs) -> None:
54-
ScaleBias.__init__(self, num_features, bias)
57+
ScaleBias.__init__(self, num_features, bias, runtime_shape)
5558
QuantWBIOL.__init__(
5659
self,
5760
weight_quant=weight_quant,

tests/brevitas/nn/test_wbiol.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from brevitas.nn import QuantConvTranspose2d
1313
from brevitas.nn import QuantConvTranspose3d
1414
from brevitas.nn import QuantLinear
15-
from brevitas.nn import QuantScaleBias
1615
from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL
1716
from brevitas.proxy import ActQuantProxyFromInjector
1817
from brevitas.proxy import BiasQuantProxyFromInjector

0 commit comments

Comments
 (0)