# Federico Brancasi <[email protected]>

import torch.fx as fx
+import torch

from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant
from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc
@@ -31,7 +32,10 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
        newLinArgs = []

        for arg in oldArgs:
-            if arg.op == "call_module" and "dequant" in arg.target.lower():
+            # FCONTI: there is no bias; propagate the None to newLinArgs
+            if arg is None:
+                newLinArgs.append(arg)
+            elif arg.op == "call_module" and "dequant" in arg.target.lower():
                if "bias_dequant" in arg.target.lower():
                    biasDequantNode = arg
                elif "weight_dequant" in arg.target.lower():
@@ -47,26 +51,48 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
        node.args = tuple(newLinArgs)

        if biasDequantNode is None:
-            # FBRANCASI: This would be unusual if a linear is missing bias or missing a bias_dequant
+            # FCONTI: this happens if a linear layer has no bias
            if debug:
-                print(f"Skipping {node.target}: no biasDequantNode found.")
-            continue
-
-        biasQuantNode = biasDequantNode.args[0]
-        if (
-            biasQuantNode.op == "call_module"
-            and "bias_quant" in biasQuantNode.target.lower()
-        ):
-            newBqArgs = list(biasQuantNode.args)
-            for i, bqArg in enumerate(newBqArgs):
-                if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
-                    newBqArgs[i] = bqArg.args[0]
-            biasQuantNode.args = tuple(newBqArgs)
+                print(f"Skipping bias for {node.target}: no biasDequantNode found.")
+            biasQuantNode = None
        else:
-            if debug:
-                print(
-                    "Warning: Did not find a typical 'bias_quant' node shape in the graph."
-                )
+            biasQuantNode = biasDequantNode.args[0]
+            if (
+                biasQuantNode.op == "call_module"
+                and "bias_quant" in biasQuantNode.target.lower()
+            ):
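+                # Bypass any Dequant nodes feeding bias_quant by re-pointing its args at their inputs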
+                newBqArgs = list(biasQuantNode.args)
+                for i, bqArg in enumerate(newBqArgs):
+                    if bqArg.op == "call_module" and "dequant" in bqArg.target.lower():
+                        newBqArgs[i] = bqArg.args[0]
+                biasQuantNode.args = tuple(newBqArgs)
+            else:
+                if debug:
+                    print(
+                        "Warning: Did not find a typical 'bias_quant' node shape in the graph."
+                    )
+
+        # FCONTI: if there is a bias node, use it for scale/zeropoint/bitwidth.
+        # otherwise, rely on weight*input
+        if biasDequantNode is not None:
+            oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
+            dequantScale = oldBiasDequantMod.scale
+            dequantZeroPoint = oldBiasDequantMod.zeroPoint
+            dequantBitWidth = oldBiasDequantMod.bitWidth
+            oldDequantMod = oldBiasDequantMod
+        else:
+            oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target)
+            oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target)
+            dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale
+            # FCONTI: technically it should be:
+            # dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights)
+            # how to appropriately compute sum(weights)?
+            # for now we restrict ourselves to oIDM.zP = 0, so dZP = 0
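+            # Sanity-check sketch (assumes Dequant computes real = scale * (q - zeroPoint), and that
+            # the weight zero-point oWDM.zP is also 0): with Wq/xq the integer weight/input and both
+            # zero-points at 0,
+            #   W @ x = (oWDM.scale * Wq) @ (oIDM.scale * xq) = (oWDM.scale * oIDM.scale) * (Wq @ xq),
+            # so a single Dequant with scale = oWDM.scale * oIDM.scale and zeroPoint = 0 reproduces
+            # the floating-point output of a bias-free linear.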
+            if debug and oldInputDequantMod.zeroPoint != 0.0:
+                print(f"Warning: input Dequant node for {node.target} has a non-zero zero-point (unsupported). Expect wrong results!")
+            dequantZeroPoint = 0.0
+            dequantBitWidth = 32  # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one?
+            oldDequantMod = oldWeightDequantMod

        for dnode in (inputDequantNode, weightDequantNode):
            if dnode is not None:
@@ -76,19 +102,17 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
                    delattr(fxModel, dnode.target)
                graph.erase_node(dnode)

-        oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target)
-
        newDequantModName = (
            node.target.replace(".wrappedInnerForwardImpl", "") + "_unified_dequant"
        )
        # JUNGVI: Torch module names cannot contain "."
        newDequantModName = newDequantModName.replace(".", "_")

        unifiedDequantMod = Dequant(
-            originalModule=oldBiasDequantMod.originalModule,
-            scale=oldBiasDequantMod.scale,
-            zeroPoint=oldBiasDequantMod.zeroPoint,
-            bitWidth=oldBiasDequantMod.bitWidth,
+            originalModule=oldDequantMod.originalModule,
+            scale=dequantScale,
+            zeroPoint=dequantZeroPoint,
+            bitWidth=dequantBitWidth,
        )

        fxModel.add_module(newDequantModName, unifiedDequantMod)
@@ -105,11 +129,12 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.GraphModule:
                    newArgs[i] = newDequantNode
            usr.args = tuple(newArgs)

-        for usr in list(biasDequantNode.users.keys()):
-            biasDequantNode.users[usr] = None
-        if hasattr(fxModel, biasDequantNode.target):
-            delattr(fxModel, biasDequantNode.target)
-        graph.erase_node(biasDequantNode)
+        if biasDequantNode is not None:
+            for usr in list(biasDequantNode.users.keys()):
+                biasDequantNode.users[usr] = None
+            if hasattr(fxModel, biasDequantNode.target):
+                delattr(fxModel, biasDequantNode.target)
+            graph.erase_node(biasDequantNode)

        if debug:
            print(cc.success(f"Modification done for {node.target}"))