
Commit b643896

Add capability to deal with linear modules (Conv, Gemm, ...) without bias

The current version of DeepQuant assumes that every linear module has a bias and otherwise skips unifying the Dequant nodes. This change enables unification of the Dequant blocks even when there is no `biasDequantNode`. The implementation is still incomplete: it assumes that the input Dequant zeroPoint is 0.
1 parent 98c828c commit b643896
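
A sketch of the arithmetic behind that zeroPoint restriction (additionally assuming a symmetric weight quantizer, i.e. weight zero-point 0, as is common for weight quantization): with the input dequantized as x̂ = s_x(q_x − z_x) and the weights as ŵ = s_w·q_w, each output element of a bias-free linear is

```latex
\hat{y} = \sum_i s_w q_{w,i} \cdot s_x (q_{x,i} - z_x)
        = s_w s_x \Big( \underbrace{\sum_i q_{w,i} q_{x,i}}_{\text{integer accumulator}} \;-\; z_x \sum_i q_{w,i} \Big)
```

so the unified Dequant gets scale s_w·s_x, while its zero point, z_x·Σᵢ q_{w,i}, depends on the per-output-row sums of the quantized weights (the `sum(weights)` mentioned in the in-code comment). With z_x = 0 the zero point vanishes, which is exactly the case this patch supports.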

File tree

1 file changed: +55 −30 lines

1 file changed

+55
-30
lines changed

DeepQuant/QuantManipulation/DequantModifier.py

Lines changed: 55 additions & 30 deletions
```diff
@@ -5,6 +5,7 @@
 # Federico Brancasi <[email protected]>
 
 import torch.fx as fx
+import torch
 
 from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant
 from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
```
```diff
@@ -31,7 +32,10 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
         newLinArgs = []
 
         for arg in oldArgs:
-            if arg.op == "call_module" and "dequant" in arg.target.lower():
+            # FCONTI: there is no Bias, propagate this to the newLinArgs
+            if arg is None:
+                newLinArgs.append(arg)
+            elif arg.op == "call_module" and "dequant" in arg.target.lower():
                 if "bias_dequant" in arg.target.lower():
                     biasDequantNode = arg
                 elif "weight_dequant" in arg.target.lower():
```
```diff
@@ -47,26 +51,48 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
         node.args = tuple(newLinArgs)
 
         if biasDequantNode is None:
-            # FBRANCASI: This would be unusual if a linear is missing bias or missing a bias_dequant
+            # FCONTI: this happens if a linear layer has no bias
             if debug:
-                print(f"Skipping {node.target}: no biasDequantNode found.")
-            continue
-
-        biasQuantNode = biasDequantNode.args[0]
-        if (
-            biasQuantNode.op == "call_module"
-            and "bias_quant" in biasQuantNode.target.lower()
-        ):
-            newBqArgs = list(biasQuantNode.args)
-            for i, bqArg in enumerate(newBqArgs):
-                if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
-                    newBqArgs[i] = bqArg.args[0]
-            biasQuantNode.args = tuple(newBqArgs)
+                print(f"Skipping bias for {node.target}: no biasDequantNode found.")
+            biasQuantNode = None
         else:
-            if debug:
-                print(
-                    "Warning: Did not find a typical 'bias_quant' node shape in the graph."
-                )
+            biasQuantNode = biasDequantNode.args[0]
+            if (
+                biasQuantNode.op == "call_module"
+                and "bias_quant" in biasQuantNode.target.lower()
+            ):
+                newBqArgs = list(biasQuantNode.args)
+                for i, bqArg in enumerate(newBqArgs):
+                    if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
+                        newBqArgs[i] = bqArg.args[0]
+                biasQuantNode.args = tuple(newBqArgs)
+            else:
+                if debug:
+                    print(
+                        "Warning: Did not find a typical 'bias_quant' node shape in the graph."
+                    )
+
+        # FCONTI: if there is a bias node, use it for scale/zeropoint/bitwidth.
+        # otherwise, rely on weight*input
+        if biasDequantNode is not None:
+            oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
+            dequantScale = oldBiasDequantMod.scale
+            dequantZeroPoint = oldBiasDequantMod.zeroPoint
+            dequantBitWidth = oldBiasDequantMod.bitWidth
+            oldDequantMod = oldBiasDequantMod
+        else:
+            oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target)
+            oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target)
+            dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale
+            # FCONTI: technically it should be:
+            # dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights)
+            # how to appropriately compute sum(weights)?
+            # for now we restrict ourselves to oIDM.zP = 0, so dZP = 0
+            if debug and oldInputDequantMod.zeroPoint != 0.0:
+                print(f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!")
+            dequantZeroPoint = 0.0
+            dequantBitWidth = 32  # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one?
+            oldDequantMod = oldWeightDequantMod
 
         for dnode in (inputDequantNode, weightDequantNode):
             if dnode is not None:
```
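
A quick numeric check of the new fallback branch: with a zero input zero-point, dequantizing the integer accumulator once with `weightScale * inputScale` is equivalent to dequantizing both operands first. This is a self-contained sketch with made-up quantization parameters (symmetric weights assumed), not DeepQuant code:

```python
import torch

sW, sX, zX = 0.05, 0.10, 0.0  # assumed scales; the patch requires zX == 0

qW = torch.randint(-8, 8, (3, 4)).float()  # fake integer weight codes
qX = torch.randint(-8, 8, (4,)).float()    # fake integer input codes

ref = (sW * qW) @ (sX * (qX - zX))  # dequantize operands, then matmul
uni = (sW * sX) * (qW @ qX)         # integer matmul, single unified dequant

print(torch.allclose(ref, uni))  # True -- and only because zX == 0
```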
```diff
@@ -76,19 +102,17 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
                 delattr(fxModel, dnode.target)
                 graph.erase_node(dnode)
 
-        oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
-
         newDequantModName = (
             node.target.replace(".wrappedInnerForwardImpl", "") + "_unified_dequant"
         )
         # JUNGVI: Torch modules name cannot contain "."
         newDequantModName = newDequantModName.replace(".", "_")
 
         unifiedDequantMod = Dequant(
-            originalModule=oldBiasDequantMod.originalModule,
-            scale=oldBiasDequantMod.scale,
-            zeroPoint=oldBiasDequantMod.zeroPoint,
-            bitWidth=oldBiasDequantMod.bitWidth,
+            originalModule=oldDequantMod.originalModule,
+            scale=dequantScale,
+            zeroPoint=dequantZeroPoint,
+            bitWidth=dequantBitWidth,
         )
 
         fxModel.add_module(newDequantModName, unifiedDequantMod)
```
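
For reference, hooking a freshly added submodule such as `unifiedDequantMod` into the graph typically follows the standard torch.fx insertion pattern; a generic sketch (not necessarily the exact helper this file uses):

```python
import torch.fx as fx

def wireAfter(gm: fx.GraphModule, anchor: fx.Node, modName: str) -> fx.Node:
    """Call submodule `modName` on `anchor`'s output and reroute all users."""
    with gm.graph.inserting_after(anchor):
        newNode = gm.graph.call_module(modName, args=(anchor,))
    anchor.replace_all_uses_with(newNode)
    newNode.args = (anchor,)  # replace_all_uses_with also rewrote our own input
    gm.recompile()            # regenerate forward() from the edited graph
    return newNode
```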
```diff
@@ -105,11 +129,12 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
                     newArgs[i] = newDequantNode
             usr.args = tuple(newArgs)
 
-        for usr in list(biasDequantNode.users.keys()):
-            biasDequantNode.users[usr] = None
-        if hasattr(fxModel, biasDequantNode.target):
-            delattr(fxModel, biasDequantNode.target)
-        graph.erase_node(biasDequantNode)
+        if biasDequantNode is not None:
+            for usr in list(biasDequantNode.users.keys()):
+                biasDequantNode.users[usr] = None
+            if hasattr(fxModel, biasDequantNode.target):
+                delattr(fxModel, biasDequantNode.target)
+            graph.erase_node(biasDequantNode)
 
         if debug:
             print(cc.success(f"Modification done for {node.target}"))
```
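
As a general torch.fx rule underlying both erasure sites above: `graph.erase_node` raises if the node still has users, so all uses must be rerouted first, and the orphaned submodule should be dropped afterwards. A generic sketch of that invariant (assumes a `call_module` node whose users were already redirected):

```python
import torch.fx as fx

def eraseModuleNode(gm: fx.GraphModule, node: fx.Node) -> None:
    assert node.op == "call_module" and len(node.users) == 0
    gm.graph.erase_node(node)
    if hasattr(gm, node.target):
        delattr(gm, node.target)  # drop the now-unreferenced submodule
    gm.recompile()
```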
