|
5 | 5 | # Federico Brancasi <[email protected]> |
6 | 6 |
|
7 | 7 | import torch.fx as fx |
8 | | -import torch |
9 | 8 |
|
10 | 9 | from DeepQuant.QuantManipulation.QuantDequantNodes import Dequant |
11 | 10 | from DeepQuant.Utils.ConsoleFormatter import ConsoleColor as cc |
@@ -76,22 +75,24 @@ def unifyLinearDequants(fxModel: fx.GraphModule, debug: bool = False) -> fx.Grap |
76 | 75 | # otherwise, rely on weight*input |
77 | 76 | if biasDequantNode is not None: |
78 | 77 | oldBiasDequantMod = fxModel.get_submodule(biasDequantNode.target) |
79 | | - dequantScale = oldBiasDequantMod.scale |
| 78 | + dequantScale = oldBiasDequantMod.scale |
80 | 79 | dequantZeroPoint = oldBiasDequantMod.zeroPoint |
81 | | - dequantBitWidth = oldBiasDequantMod.bitWidth |
| 80 | + dequantBitWidth = oldBiasDequantMod.bitWidth |
82 | 81 | oldDequantMod = oldBiasDequantMod |
83 | 82 | else: |
84 | | - oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target) |
| 83 | + oldInputDequantMod = fxModel.get_submodule(inputDequantNode.target) |
85 | 84 | oldWeightDequantMod = fxModel.get_submodule(weightDequantNode.target) |
86 | | - dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale |
| 85 | + dequantScale = oldWeightDequantMod.scale * oldInputDequantMod.scale |
87 | 86 | # FCONTI: technically it should be: |
88 | 87 | # dZP = oWDM.zP * oIDM.zP - oWDM.scale * oIDM.zP * sum(weights) |
89 | 88 | # how to appropriately compute sum(weights)? |
90 | 89 | # for now we restrict ourselves to oIDM.zP = 0, so dZP = 0 |
91 | 90 | if debug and oldInputDequantMod.zeroPoint != 0.0: |
92 | | - print(f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!") |
| 91 | + print( |
| 92 | + f"Warning: input Dequant node for {node.target} has non-zero zero-point (unsupported). Expect wrong results!" |
| 93 | + ) |
93 | 94 | dequantZeroPoint = 0.0 |
94 | | - dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one? |
| 95 | + dequantBitWidth = 32 # FCONTI: this is simply a reasonable assumption: is there a less arbitrary one? |
95 | 96 | oldDequantMod = oldWeightDequantMod |
96 | 97 |
|
97 | 98 | for dnode in (inputDequantNode, weightDequantNode): |
|
0 commit comments