Skip to content

Commit e2bc1b4

Browse files
Change Rounding Policy in Quant Module
1 parent cfd21c3 commit e2bc1b4

File tree

2 files changed

+10
-8
lines changed

2 files changed

+10
-8
lines changed

DeepQuant/QuantManipulation/DequantModifier.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
# Federico Brancasi <[email protected]>
66

77
import torch.fx as fx
8-
import torch
98

109
from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant
1110
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
@@ -76,22 +75,24 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.Grap
7675
# otherwise, rely on weight*input
7776
if biasDequantNode is not None:
7877
oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
79-
dequantScale = oldBiasDequantMod.scale
78+
dequantScale = oldBiasDequantMod.scale
8079
dequantZeroPoint = oldBiasDequantMod.zeroPoint
81-
dequantBitWidth = oldBiasDequantMod.bitWidth
80+
dequantBitWidth = oldBiasDequantMod.bitWidth
8281
oldDequantMod = oldBiasDequantMod
8382
else:
84-
oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target)
83+
oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target)
8584
oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target)
86-
dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale
85+
dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale
8786
# FCONTI: technically it should be:
8887
# dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights)
8988
# how to appropriately compute sum(weights)?
9089
# for now we restrict ourselves to oIDM.zP = 0, so dZP = 0
9190
if debug and oldInputDequantMod.zeroPoint != 0.0:
92-
print(f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!")
91+
print(
92+
f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!"
93+
)
9394
dequantZeroPoint = 0.0
94-
dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one?
95+
dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one?
9596
oldDequantMod = oldWeightDequantMod
9697

9798
for dnode in (inputDequantNode, weightDequantNode):

DeepQuant/QuantManipulation/QuantDequantNodes.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
4747

4848
xScaled = x / self.scale
4949
xShifted = xScaled + self.zeroPoint
50-
xRounded = torch.round(xShifted)
50+
xRounded = torch.floor(xShifted + 0.5)
51+
5152
if self.bitWidth is not None:
5253
xRounded = torch.clamp(xRounded, self.minVal, self.maxVal)
5354
return xRounded

0 commit comments

Comments (0)