@@ -99,7 +99,7 @@ def call(self, *args, **kwargs):
9999 tensors = args
100100 shapes = ((1 ,) + shape [1 :] for shape in map (ops .shape , tensors ))
101101 if self .enable_ebops :
102- ebops = self ._compute_ebops (* shapes )
102+ ebops = self ._compute_ebops (* shapes ) * self . ebops_factor
103103 self ._ebops .assign (ops .cast (ebops , self ._ebops .dtype ))
104104 self .add_loss (ebops * self .beta )
105105 if not self .enable_oq or self .__output_quantizer_handled__ :
@@ -136,6 +136,7 @@ def __init__(
136136 enable_oq : bool | None = None ,
137137 enable_iq : bool | None = None ,
138138 oq_conf : QuantizerConfig | None = None ,
139+ ebops_factor : float = 1.0 ,
139140 ** kwargs ,
140141 ):
141142 super ().__init__ (** kwargs )
@@ -148,6 +149,7 @@ def __init__(
148149 enable_ebops = global_config ['enable_ebops' ] and self .enable_iq
149150 self ._enable_ebops = enable_ebops
150151 self ._beta0 = beta0
152+ self .ebops_factor = ebops_factor
151153
152154 if self .enable_oq :
153155 oq_conf = oq_conf or QuantizerConfig ('default' , 'datalane' )
@@ -219,6 +221,7 @@ def get_config(self):
219221 config .update (
220222 {
221223 'enable_ebops' : self .enable_ebops ,
224+ 'ebops_factor' : self .ebops_factor ,
222225 'beta0' : self ._beta0 ,
223226 'enable_oq' : self .enable_oq ,
224227 'enable_iq' : self .enable_iq ,
0 commit comments