Skip to content

Commit b998c38

Browse files
Rename for better understanding
1 parent 516605e commit b998c38

File tree

5 files changed

+18
-18
lines changed

5 files changed

+18
-18
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from torch import Tensor
1010

1111

12-
class WrapperLinear(nn.Module):
12+
class WrapperWBIOL(nn.Module):
1313
"""Expose `inner_forward_impl` as a standalone submodule."""
1414

1515
def __init__(self, innerForwardImpl: nn.Module) -> None:
@@ -22,7 +22,7 @@ def forward(
2222
return self.innerForwardImpl(quantInput, quantWeight, quantBias)
2323

2424

25-
def linearForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
25+
def WBIOLForward(self: QuantWeightBiasInputOutputLayer, inp: Tensor) -> Tensor:
2626
"""Quant-in → quant-weight/bias → matmul → quant-out."""
2727
quantInput = self.input_quant(inp)
2828
quantWeight = self.weight_quant(self.weight)

DeepQuant/Pipeline/QuantSplit.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import torch
1010
import torch.nn as nn
1111

12-
from DeepQuant.QuantManipulation.ParameterExtractor import (
12+
from DeepQuant.QuantManipulation.QuantizationParameterExtractor import (
1313
extractBrevitasProxyParams,
1414
printQuantParams,
1515
)

DeepQuant/QuantManipulation/ParameterExtractor.py renamed to DeepQuant/QuantManipulation/QuantizationParameterExtractor.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector
1616

1717

18-
def safeGetScale(quantObj: Any) -> Any:
19-
"""Safely extract scale parameter from quantization object."""
18+
def getScale(quantObj: Any) -> Any:
19+
"""Extract scale parameter from quantization object."""
2020
if quantObj is None:
2121
return None
2222
maybeScale = quantObj.scale() if callable(quantObj.scale) else quantObj.scale
@@ -32,8 +32,8 @@ def safeGetScale(quantObj: Any) -> Any:
3232
return None
3333

3434

35-
def safeGetZeroPoint(quantObj: Any) -> Any:
36-
"""Safely extract zero point parameter from quantization object."""
35+
def getZeroPoint(quantObj: Any) -> Any:
36+
"""Extract zero point parameter from quantization object."""
3737
if quantObj is None:
3838
return None
3939
maybeZp = (
@@ -51,16 +51,16 @@ def safeGetZeroPoint(quantObj: Any) -> Any:
5151
return None
5252

5353

54-
def safeGetIsSigned(quantObj: Any) -> bool:
55-
"""Safely determine if quantization is signed."""
54+
def getIsSigned(quantObj: Any) -> bool:
55+
"""Determine if quantization is signed."""
5656
if hasattr(quantObj, "is_signed"):
5757
return getattr(quantObj, "is_signed")
5858
if hasattr(quantObj, "min_val"):
5959
try:
6060
return quantObj.min_val < 0
6161
except Exception:
6262
pass
63-
zp = safeGetZeroPoint(quantObj)
63+
zp = getZeroPoint(quantObj)
6464
if zp is not None:
6565
# If zero_point is near zero, assume unsigned quantization.
6666
return not (abs(zp) < 1e-5)
@@ -82,10 +82,10 @@ def recurseModules(parentMod: nn.Module, prefix: str = "") -> None:
8282
BiasQuantProxyFromInjector,
8383
),
8484
):
85-
scl = safeGetScale(childMod)
86-
zp = safeGetZeroPoint(childMod)
85+
scl = getScale(childMod)
86+
zp = getZeroPoint(childMod)
8787
bw = childMod.bit_width()
88-
isSigned = safeGetIsSigned(childMod)
88+
isSigned = getIsSigned(childMod)
8989
paramsDict[fullName] = {
9090
"scale": scl,
9191
"zero_point": zp,

DeepQuant/Transforms/Transformations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from brevitas.nn.quant_mha import QuantMultiheadAttention
1515

1616
from DeepQuant.CustomForwards.Activations import WrapperActivation, activationForward
17-
from DeepQuant.CustomForwards.Linear import WrapperLinear, linearForward
17+
from DeepQuant.CustomForwards.WBIOL import WBIOLForward, WrapperWBIOL
1818
from DeepQuant.CustomForwards.MultiHeadAttention import mhaForward
1919
from DeepQuant.Transforms.Base import TransformationPass
2020
from DeepQuant.Utils.CustomTracer import QuantTracer
@@ -33,11 +33,11 @@ def injectForward(
3333
self, module: nn.Module, tracer: Optional[QuantTracer] = None
3434
) -> None:
3535
"""Inject custom forward for linear layers."""
36-
module.wrappedInnerForwardImpl = WrapperLinear(module.inner_forward_impl)
37-
module.forward = linearForward.__get__(module)
36+
module.wrappedInnerForwardImpl = WrapperWBIOL(module.inner_forward_impl)
37+
module.forward = WBIOLForward.__get__(module)
3838

3939
if tracer:
40-
tracer.registerLeafModule(WrapperLinear)
40+
tracer.registerLeafModule(WrapperWBIOL)
4141
tracer.registerNonLeafModule(QuantWeightBiasInputOutputLayer)
4242

4343

Tests/TestYOLOv5.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def prepareYOLOv5Backbone() -> nn.Module:
2323
"""Prepare a quantized partial YOLOv5 model for testing."""
2424
from ultralytics import YOLO
2525

26-
model = YOLO("Models/yolov5n.pt")
26+
model = YOLO("Models/yolov5nu.pt")
2727
pytorchModel = model.model
2828

2929
# FBRANCASI: Just first few layers for simplicity

0 commit comments

Comments
 (0)