22# SPDX-License-Identifier: BSD-3-Clause
33
44from typing import Optional
5+ from typing import Tuple
56from typing import Type
67from typing import Union
78
2526
2627class ScaleBias (Module ):
2728
28- def __init__ (self , num_features : int , bias : bool , runtime_shape = (1 , - 1 , 1 , 1 )):
29+ def __init__ (
30+ self , num_features : int , bias : bool , runtime_shape : Tuple [int , ...] = (1 , - 1 , 1 , 1 )):
2931 super (ScaleBias , self ).__init__ ()
3032 self .num_features = num_features
3133 self .weight = Parameter (torch .ones (num_features ))
@@ -49,9 +51,10 @@ def __init__(
4951 bias_quant : Optional [BiasQuantType ] = None ,
5052 input_quant : Optional [ActQuantType ] = None ,
5153 output_quant : Optional [ActQuantType ] = None ,
54+ runtime_shape : Tuple [int , ...] = (1 , - 1 , 1 , 1 ),
5255 return_quant_tensor : bool = False ,
5356 ** kwargs ) -> None :
54- ScaleBias .__init__ (self , num_features , bias )
57+ ScaleBias .__init__ (self , num_features , bias , runtime_shape )
5558 QuantWBIOL .__init__ (
5659 self ,
5760 weight_quant = weight_quant ,
0 commit comments