
Commit 587494a

feat (quant/mx): Added midmax scale rounding option to MX types (#1409)
1 parent 1d1fd5a commit 587494a

9 files changed: +93 -22 lines changed

src/brevitas/core/scaling/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@
 from brevitas.core.stats import SCALAR_SHAPE

 from .float_scaling import FloatScaling
+from .float_scaling import RoundMidMaxSte
 from .int_scaling import IntScaling
 from .int_scaling import PowerOfTwoIntScaling
 from .pre_scaling import AccumulatorAwareParameterPreScaling

src/brevitas/core/scaling/float_scaling.py

Lines changed: 17 additions & 0 deletions
@@ -7,10 +7,13 @@

 import torch
 from torch import Tensor
+import torch.nn as nn

 import brevitas
 from brevitas.core.utils import StatelessBuffer
+from brevitas.function.ops import calculate_midmax_bias
 from brevitas.function.ops import max_float
+from brevitas.function.ops_ste import ceil_ste


 class FloatScaling(brevitas.jit.ScriptModule):
@@ -44,3 +47,17 @@ def forward(
         max_value = max_value if self.max_available_float is None else torch.min(
             max_value, self.max_available_float())
         return max_value
+
+
+class RoundMidMaxSte(brevitas.jit.ScriptModule):
+
+    def __init__(self, mantissa_bit_width_impl: nn.Module, midmax_mantissa_bit_bias: float = 0.0):
+        super().__init__()
+        self.mantissa_bit_width_impl = mantissa_bit_width_impl
+        self.midmax_mantissa_bit_bias = midmax_mantissa_bit_bias
+
+    @brevitas.jit.script_method
+    def forward(self, x: Tensor) -> Tensor:
+        return ceil_ste(
+            x -
+            calculate_midmax_bias(self.mantissa_bit_width_impl(), self.midmax_mantissa_bit_bias))
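Reading the new module: ceil_ste(x - bias) behaves like the existing floor rounding as long as the fractional part of the log2-domain input stays at or below the midmax bias, and rounds up (doubling the power-of-two scale) once it crosses it. A minimal, self-contained sketch of that behaviour, with the bias expression copied from calculate_midmax_bias and hypothetical log2 values (the STE wrappers are omitted):

import torch

def midmax_bias(mantissa_bit_width: float, bit_bias: float = 0.0) -> torch.Tensor:
    # Same expression as the new brevitas.function.ops.calculate_midmax_bias:
    # log2(2 - 2**(-m - 1 + bias)); the extra 1 accounts for the implicit mantissa bit.
    return torch.log2(torch.tensor(2.0 - 2.0 ** (-mantissa_bit_width - 1 + bit_bias)))

x = torch.tensor([2.10, 2.90, 3.05])  # hypothetical unrounded log2 threshold values
m = 1.0                               # e2m1 elements: one explicit mantissa bit, bias ~= 0.807

print(torch.floor(x))                  # tensor([2., 2., 3.])  default MX behaviour (FloorSte)
print(torch.ceil(x - midmax_bias(m)))  # tensor([2., 3., 3.])  RoundMidMaxSte without the STE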

src/brevitas/core/scaling/int_scaling.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
 from typing import Union

 import torch
+from torch import nn
 from torch import Tensor

 import brevitas

src/brevitas/function/ops.py

Lines changed: 7 additions & 0 deletions
@@ -212,6 +212,13 @@ def max_float(exponent_bit_width: Tensor, max_mantissa: Tensor, exponent_bias: T
     return max_val


+@brevitas.jit.script
+def calculate_midmax_bias(mantissa_bit_width: Tensor, midmax_mantissa_bit_bias: float) -> Tensor:
+    return torch.log2(
+        (2 -
+         2 ** (-mantissa_bit_width - 1 + midmax_mantissa_bit_bias)))  # extra 1 for the implicit bit
+
+
 def get_upper_bound_on_l1_norm(
         accumulator_bit_width: Tensor, input_bit_width: Tensor, input_is_signed: bool) -> Tensor:
     """Calculate the upper bound on the l1-norm of the weights needed to guarantee overflow avoidance

src/brevitas/quant/experimental/mx_quant_ocp.py

Lines changed: 11 additions & 11 deletions
@@ -29,6 +29,7 @@
 from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight
 from brevitas.quant.solver.act import ActQuantSolver
 from brevitas.quant.solver.weight import WeightQuantSolver
+from brevitas.utils.float_quant_utils import get_midmax_mantissa_bit_bias


 class GroupwiseWeightFloatProxyMixin(ExtendedInjector):
@@ -52,7 +53,7 @@ class RestrictThresholdMixin(ExtendedInjector):
     restrict_scaling_impl = PowerOfTwoRestrictValue


-class MXWeightMixin(ExtendedInjector):
+class MXMixin(ExtendedInjector):
     threshold_mixin = RestrictThresholdMixin
     group_size = 32
     restrict_scaling_type = RestrictValueType.POWER_OF_TWO
@@ -63,14 +64,17 @@ class MXWeightMixin(ExtendedInjector):
     def restrict_threshold_impl():
         return this.threshold_mixin.restrict_scaling_impl

+    @value
+    def midmax_mantissa_bit_bias(mantissa_bit_width, nan_values, inf_values):
+        return get_midmax_mantissa_bit_bias(mantissa_bit_width, nan_values, inf_values)

-class MXActMixin(ExtendedInjector):
-    threshold_mixin = RestrictThresholdMixin
-    group_size = 32
-    restrict_scaling_type = RestrictValueType.POWER_OF_TWO
-    restrict_value_float_to_int_impl = FloorSte
+
+class MXWeightMixin(MXMixin):
+    pass
+
+
+class MXActMixin(MXMixin):
     scaling_impl = RuntimeDynamicGroupStatsScaling
-    scaling_per_output_type = ScalingPerOutputType.GROUP

     @value
     def stats_reduce_dim(group_dim):
@@ -80,10 +84,6 @@ def stats_reduce_dim(group_dim):
         else:
             return group_dim + 1

-    @value
-    def restrict_threshold_impl():
-        return this.threshold_mixin.restrict_scaling_impl
-

 class MXFloat8e4m3Weight(MXWeightMixin,
                          GroupwiseWeightFloatProxyMixin,

src/brevitas/utils/float_quant_utils.py

Lines changed: 17 additions & 0 deletions
@@ -95,3 +95,20 @@ def get_min_available_float(
     min_value = get_minifloat_value(
         exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
     return min_value
+
+
+# TODO: Allow dynamically changing this value at runtime
+def get_midmax_mantissa_bit_bias(
+        mantissa_bit_width: int, nan_values: Tuple[str], inf_values: Tuple[str]) -> float:
+    # Calculate how much bias needs to be added to the midmax calculation, based on the number of values reserved for inf/nan
+    num_inf_values = 0 if inf_values is None else len(inf_values)
+    num_nan_values = 0 if nan_values is None else len(nan_values)
+    total_reserved_values = num_inf_values + num_nan_values
+    excess_reserved_values = total_reserved_values % 2 ** mantissa_bit_width  # How many extra values are reserved for the highest valid exponent
+    if excess_reserved_values == 0:
+        return 0.0  # No special reserved mantissa values at maximum valid mantissa
+    elif (excess_reserved_values + 1) == 2 ** mantissa_bit_width:
+        return 0.0  # Edge case when only f'0{mantissa_bit_width}b' is representable at the maximum mantissa
+    else:
+        return torch.log2(torch.tensor(excess_reserved_values + 1)).item(
+        )  # The number of bits of the mantissa that are consumed by the reserved values
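A hedged walk-through of the new helper, assuming the OCP-style reserved encodings used elsewhere in Brevitas (E4M3 reserving a single NaN mantissa pattern and no Inf, E5M2 reserving three NaN patterns plus an Inf pattern):

from brevitas.utils.float_quant_utils import get_midmax_mantissa_bit_bias

# E4M3: 1 reserved value, 1 % 2**3 = 1 -> one mantissa step is lost at the top exponent,
# so the bias is log2(1 + 1) = 1.0
print(get_midmax_mantissa_bit_bias(3, ('111',), None))

# E5M2: 4 reserved values, 4 % 2**2 = 0 -> the whole top exponent is reserved and the
# highest valid exponent keeps its full mantissa range, so the bias is 0.0
print(get_midmax_mantissa_bit_bias(2, ('01', '10', '11'), ('00',)))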

src/brevitas_examples/common/generative/quantize.py

Lines changed: 5 additions & 2 deletions
@@ -12,6 +12,7 @@
 from brevitas.core.function_wrapper import CeilSte
 from brevitas.core.function_wrapper import FloorSte
 from brevitas.core.restrict_val import RoundSte
+from brevitas.core.scaling import RoundMidMaxSte
 from brevitas.core.stats import NegativeMinOrZero
 from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
 from brevitas.graph.quantize import layerwise_quantize
@@ -306,7 +307,8 @@ def quant_format_from_string(quant_format):
         attn_kwargs = dict()

     if scale_rounding_func_type is not None:
-        scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
+        scale_rounding_func_dict = {
+            'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte, 'midmax': RoundMidMaxSte}
         scale_type = scale_rounding_func_dict[scale_rounding_func_type]
         input_kwargs = {**input_kwargs, **{'restrict_value_float_to_int_impl': scale_type}}

@@ -371,7 +373,8 @@ def quant_format_from_string(quant_format):
             **weight_float_format)

     if scale_rounding_func_type is not None:
-        scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
+        scale_rounding_func_dict = {
+            'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte, 'midmax': RoundMidMaxSte}
         scale_type = scale_rounding_func_dict[scale_rounding_func_type]
         weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})

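The same override can also be applied directly to a quantizer through the injector's let, mirroring the weight_quant.let(...) call above when scale_rounding_func_type == 'midmax' (a sketch; the derived class name is chosen here for illustration):

from brevitas.core.scaling import RoundMidMaxSte
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Weight

# Swap the default FloorSte power-of-two rounding for the new midmax rounding.
MXFloat8e4m3WeightMidMax = MXFloat8e4m3Weight.let(restrict_value_float_to_int_impl=RoundMidMaxSte)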

src/brevitas_examples/llm/llm_args.py

Lines changed: 1 addition & 1 deletion
@@ -113,7 +113,7 @@ def create_args_parser() -> ArgumentParser:
         '--scale-rounding-func-type',
         type=str,
         default=None,
-        choices=['round', 'ceil', 'floor'],
+        choices=['round', 'ceil', 'floor', 'midmax'],
         help='Rounding function to use with Po2 scale. Default: None.')
     parser.add_argument(
         '--weight-group-dim',
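A quick sanity check that the new choice is accepted by the LLM example's parser, assuming brevitas_examples is importable as packaged and the remaining arguments keep their defaults:

from brevitas_examples.llm.llm_args import create_args_parser

parser = create_args_parser()
args = parser.parse_args(['--scale-rounding-func-type', 'midmax'])
print(args.scale_rounding_func_type)  # 'midmax'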

tests/brevitas/core/test_quant_mx.py

Lines changed: 33 additions & 8 deletions
@@ -11,10 +11,12 @@
 from typing import Tuple
 from typing import Union

+from hypothesis import example
 from hypothesis import given
 import pytest_cases
 import torch

+from brevitas.core.scaling import RoundMidMaxSte
 from brevitas.nn.quant_activation import QuantIdentity
 from brevitas.nn.quant_linear import QuantLinear
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
@@ -139,38 +141,60 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):


 MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)}
+SCALE_ROUNDING = ["floor", "midmax"]
+RANGE = 1e9

+edge_cases = [
+    torch.tensor([[7.5 if i == 0 else 1.0 for i in range(32)]]),  # MXFP4, MidMax
+    torch.tensor([[3.875 if i == 0 else 1.0 for i in range(32)]]),  # MXFP8E4M3, MidMax
+]

-@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
+
+@example(inp=edge_cases[0])
+@example(inp=edge_cases[1])
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
-def test_act_mx(inp, bit_widths):
+@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
+def test_act_mx(inp, bit_widths, scale_rounding):
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]

+    extra_kwargs = {}
+    # Default rounding should be 'floor', so only create this dict if we are overriding the default
+    if scale_rounding == "midmax":
+        extra_kwargs["restrict_value_float_to_int_impl"] = RoundMidMaxSte
     act_quant = QuantIdentity(
         MXFloat8e4m3Act,
         exponent_bit_width=exp,
         mantissa_bit_width=mant,
         bit_width=mant + exp + 1,
         group_dim=1,
-        return_quant_tensor=True)
+        return_quant_tensor=True,
+        **extra_kwargs)
     act_quant.eval()
     x = inp

     quantizer = MXFP(bit_widths)

     qx = act_quant(x)

-    y = quantizer.quantize(x)
+    y = quantizer.quantize(x, select=scale_rounding == "midmax")
     assert torch.allclose(qx.value, y, atol=1e-8)


-@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
+@example(inp=edge_cases[0])
+@example(inp=edge_cases[1])
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
+@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
 @pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
-def test_weight_mx(inp, bit_widths, weight_quant_type):
+def test_weight_mx(inp, bit_widths, scale_rounding, weight_quant_type):
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
+    extra_kwargs = {}
+    # Default rounding should be 'floor', so only create this dict if we are overriding the default
+    if scale_rounding == "midmax":
+        extra_kwargs["weight_restrict_value_float_to_int_impl"] = RoundMidMaxSte
     weight_quant = QuantLinear(
         32,
         1,
@@ -179,7 +203,8 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
         weight_scaling_impl_type=weight_quant_type,
         weight_exponent_bit_width=exp,
         weight_mantissa_bit_width=mant,
-        weight_bit_width=mant + exp + 1)
+        weight_bit_width=mant + exp + 1,
+        **extra_kwargs)

     x = inp
     weight_quant.weight.data = x
@@ -189,6 +214,6 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
     qx_weight = weight_quant.quant_weight()
     qx_weight_two = weight_quant.quant_weight()

-    y = quantizer.quantize(x)
+    y = quantizer.quantize(x, select=scale_rounding == "midmax")
     assert torch.allclose(qx_weight.value, y, atol=1e-8)
     assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
