Skip to content

Commit d9cda6a

Browse files
committed
feat(ebops): allow per-layer ebops scaling factor
1 parent 39970c9 commit d9cda6a

File tree

1 file changed: +4 additions, -1 deletion

src/hgq/layers/core/base.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -99,7 +99,7 @@ def call(self, *args, **kwargs):
9999
tensors = args
100100
shapes = ((1,) + shape[1:] for shape in map(ops.shape, tensors))
101101
if self.enable_ebops:
102-
ebops = self._compute_ebops(*shapes)
102+
ebops = self._compute_ebops(*shapes) * self.ebops_factor
103103
self._ebops.assign(ops.cast(ebops, self._ebops.dtype))
104104
self.add_loss(ebops * self.beta)
105105
if not self.enable_oq or self.__output_quantizer_handled__:
@@ -136,6 +136,7 @@ def __init__(
136136
enable_oq: bool | None = None,
137137
enable_iq: bool | None = None,
138138
oq_conf: QuantizerConfig | None = None,
139+
ebops_factor: float = 1.0,
139140
**kwargs,
140141
):
141142
super().__init__(**kwargs)
@@ -148,6 +149,7 @@ def __init__(
148149
enable_ebops = global_config['enable_ebops'] and self.enable_iq
149150
self._enable_ebops = enable_ebops
150151
self._beta0 = beta0
152+
self.ebops_factor = ebops_factor
151153

152154
if self.enable_oq:
153155
oq_conf = oq_conf or QuantizerConfig('default', 'datalane')
@@ -219,6 +221,7 @@ def get_config(self):
219221
config.update(
220222
{
221223
'enable_ebops': self.enable_ebops,
224+
'ebops_factor': self.ebops_factor,
222225
'beta0': self._beta0,
223226
'enable_oq': self.enable_oq,
224227
'enable_iq': self.enable_iq,

0 commit comments

Comments (0)