Skip to content

Commit 493445c

Browse files
authored
Fix (graph/qronos): Normalize contribution to H and G when buffer is disabled (#1440)
1 parent c520d81 commit 493445c

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

src/brevitas/graph/qronos.py

Lines changed: 5 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -62,19 +62,22 @@ def update_batch(self, module, input, current_layer):
6262
if not is_quant_enabled:
6363
# Computing the normalized G matrix
6464
self.G *= (self.nsamples - batch_size) / self.nsamples
65+
inp_processed /= math.sqrt(
66+
self.nsamples) # NOTE: quant_input is normalized before, in the H update
6567
if self.use_intermediate_buffer:
6668
self.B.copy_(inp_processed.bmm(self.quant_input.transpose(2, 1)))
67-
self.G += (self.B / self.nsamples)
69+
self.G += self.B
6870
else:
6971
self.G += inp_processed.bmm(self.quant_input.transpose(2, 1))
7072
self.quant_input = None # NOTE: set back to None now that we've used it
7173
else:
7274
# Computing the normalized H matrix
7375
self.nsamples += batch_size # NOTE: only increment with quant inputs
7476
self.H *= (self.nsamples - batch_size) / self.nsamples
77+
inp_processed /= math.sqrt(self.nsamples)
7578
if self.use_intermediate_buffer:
7679
self.B.copy_(inp_processed.bmm(inp_processed.transpose(2, 1)))
77-
self.H += (self.B / self.nsamples)
80+
self.H += self.B
7881
else:
7982
self.H += inp_processed.bmm(inp_processed.transpose(2, 1))
8083
# store the quantized input for computing the G matrix

0 commit comments

Comments
 (0)