1 change: 1 addition & 0 deletions src/brevitas/core/scaling/__init__.py
@@ -10,6 +10,7 @@
from brevitas.core.stats import SCALAR_SHAPE

from .float_scaling import FloatScaling
from .float_scaling import RoundMidMaxSte
from .int_scaling import IntScaling
from .int_scaling import PowerOfTwoIntScaling
from .pre_scaling import AccumulatorAwareParameterPreScaling
23 changes: 23 additions & 0 deletions src/brevitas/core/scaling/float_scaling.py
@@ -7,10 +7,12 @@

import torch
from torch import Tensor
import torch.nn as nn

import brevitas
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float
from brevitas.function.ops_ste import ceil_ste


class FloatScaling(brevitas.jit.ScriptModule):
@@ -44,3 +46,24 @@ def forward(
max_value = max_value if self.max_available_float is None else torch.min(
max_value, self.max_available_float())
return max_value


@brevitas.jit.script
def _calculate_midmax_bias(mantissa_bit_width: Tensor, midmax_mantissa_bit_bias: float) -> Tensor:
return torch.log2(
(2 -
2 ** (-mantissa_bit_width - 1 + midmax_mantissa_bit_bias))) # extra 1 for the implicit bit


class RoundMidMaxSte(brevitas.jit.ScriptModule):

def __init__(self, mantissa_bit_width_impl: nn.Module, midmax_mantissa_bit_bias: float = 0.0):
super().__init__()
self.mantissa_bit_width_impl = mantissa_bit_width_impl
self.midmax_mantissa_bit_bias = midmax_mantissa_bit_bias

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
return ceil_ste(
x -
_calculate_midmax_bias(self.mantissa_bit_width_impl(), self.midmax_mantissa_bit_bias))
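Note on the bias above: with midmax_mantissa_bit_bias at zero it equals log2(2 - 2^-(m+1)), the log2 of the midpoint between the largest m-bit mantissa (2 - 2^-m) and 2.0, so ceil_ste(x - bias) rounds a log2-domain threshold up only once it crosses that midpoint. A minimal standalone sketch verifying this numerically, assuming the private helper is importable from the module shown above:

import torch

from brevitas.core.scaling.float_scaling import _calculate_midmax_bias

# m = 1 (e.g. the e2m1 element format): the largest mantissa is 1.5, so the
# midpoint between 1.5 and 2.0 is 1.75 and the bias is log2(1.75)
bias = _calculate_midmax_bias(torch.tensor(1.0), 0.0)
assert torch.isclose(bias, torch.log2(torch.tensor(1.75)))

# In the forward pass ceil_ste behaves like torch.ceil: a log2 threshold just
# below the midpoint keeps its exponent, just above it rounds up
for mantissa, expected in [(1.70, 0.0), (1.80, 1.0)]:
    x = torch.log2(torch.tensor(mantissa))
    assert torch.ceil(x - bias).item() == expected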
1 change: 1 addition & 0 deletions src/brevitas/core/scaling/int_scaling.py
@@ -5,6 +5,7 @@
from typing import Union

import torch
from torch import nn
from torch import Tensor

import brevitas
22 changes: 11 additions & 11 deletions src/brevitas/quant/experimental/mx_quant_ocp.py
@@ -29,6 +29,7 @@
from brevitas.quant.experimental.float_quant_ocp import FpOCPWeight
from brevitas.quant.solver.act import ActQuantSolver
from brevitas.quant.solver.weight import WeightQuantSolver
from brevitas.utils.float_quant_utils import get_midmax_mantissa_bit_bias


class GroupwiseWeightFloatProxyMixin(ExtendedInjector):
@@ -52,7 +53,7 @@ class RestrictThresholdMixin(ExtendedInjector):
restrict_scaling_impl = PowerOfTwoRestrictValue


class MXWeightMixin(ExtendedInjector):
class MXMixin(ExtendedInjector):
threshold_mixin = RestrictThresholdMixin
group_size = 32
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
@@ -63,14 +64,17 @@ class MXWeightMixin(ExtendedInjector):
def restrict_threshold_impl():
return this.threshold_mixin.restrict_scaling_impl

@value
def midmax_mantissa_bit_bias(mantissa_bit_width, nan_values, inf_values):
return get_midmax_mantissa_bit_bias(mantissa_bit_width, nan_values, inf_values)

class MXActMixin(ExtendedInjector):
threshold_mixin = RestrictThresholdMixin
group_size = 32
restrict_scaling_type = RestrictValueType.POWER_OF_TWO
restrict_value_float_to_int_impl = FloorSte

class MXWeightMixin(MXMixin):
pass


class MXActMixin(MXMixin):
scaling_impl = RuntimeDynamicGroupStatsScaling
scaling_per_output_type = ScalingPerOutputType.GROUP

@value
def stats_reduce_dim(group_dim):
@@ -80,10 +84,6 @@ def stats_reduce_dim(group_dim):
else:
return group_dim + 1

@value
def restrict_threshold_impl():
return this.threshold_mixin.restrict_scaling_impl


class MXFloat8e4m3Weight(MXWeightMixin,
GroupwiseWeightFloatProxyMixin,
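Since the default Po2 scale rounding stays floor (per the test comments below), midmax is opt-in. A hypothetical override, mirroring how the tests enable it via the quantizer keyword rather than anything in this file:

from brevitas.core.scaling import RoundMidMaxSte
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act

# Swap the Po2 scale rounding from floor to midmax for an MX activation quantizer
act_quant = QuantIdentity(
    MXFloat8e4m3Act,
    group_dim=1,
    return_quant_tensor=True,
    restrict_value_float_to_int_impl=RoundMidMaxSte)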
17 changes: 17 additions & 0 deletions src/brevitas/utils/float_quant_utils.py
@@ -95,3 +95,20 @@ def get_min_available_float(
min_value = get_minifloat_value(
exponent=exponent, mantissa=mantissa, exponent_bias=exponent_bias)
return min_value


# TODO: Allow dynamically changing this value at runtime
def get_midmax_mantissa_bit_bias(
mantissa_bit_width: int, nan_values: Tuple[str], inf_values: Tuple[str]) -> float:
# Calculate how much bias needs to be added to the midmax calculation, based on the number of values reserved for inf and nan
num_inf_values = 0 if inf_values is None else len(inf_values)
num_nan_values = 0 if nan_values is None else len(nan_values)
total_reserved_values = num_inf_values + num_nan_values
excess_reserved_values = total_reserved_values % 2 ** mantissa_bit_width # How many extra values are reserved for the highest valid exponent
if excess_reserved_values == 0:
return 0.0 # No reserved mantissa values at the maximum valid exponent
elif (excess_reserved_values + 1) == 2 ** mantissa_bit_width:
return 0.0 # Edge case where only the all-zeros mantissa code is representable at the maximum exponent
else:
return torch.log2(torch.tensor(excess_reserved_values + 1)).item() # Mantissa bits consumed by the reserved values
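Traced on the OCP FP8 formats (a standalone check; the reserved nan/inf mantissa encodings are assumed to follow the OCP spec as used elsewhere in Brevitas): E4M3 reserves a single NaN code at the top exponent, so one mantissa step is lost and the bias is log2(1 + 1) = 1.0, while E5M2 reserves a whole exponent's worth of codes and needs no bias.

from brevitas.utils.float_quant_utils import get_midmax_mantissa_bit_bias

# OCP FP8 E4M3: one NaN mantissa code ('111') shares the maximum exponent
assert get_midmax_mantissa_bit_bias(3, ('111',), None) == 1.0

# OCP FP8 E5M2: one inf and three NaN codes consume the top exponent entirely
assert get_midmax_mantissa_bit_bias(2, ('01', '10', '11'), ('00',)) == 0.0

# Formats with no reserved codes (e.g. e2m1) need no bias
assert get_midmax_mantissa_bit_bias(1, None, None) == 0.0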
7 changes: 5 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
@@ -12,6 +12,7 @@
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.scaling import RoundMidMaxSte
from brevitas.core.stats import NegativeMinOrZero
from brevitas.core.zero_point import ParameterFromStatsFromParameterZeroPoint
from brevitas.graph.quantize import layerwise_quantize
@@ -306,7 +307,8 @@ def quant_format_from_string(quant_format):
attn_kwargs = dict()

if scale_rounding_func_type is not None:
scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
scale_rounding_func_dict = {
'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte, 'midmax': RoundMidMaxSte}
scale_type = scale_rounding_func_dict[scale_rounding_func_type]
input_kwargs = {**input_kwargs, **{'restrict_value_float_to_int_impl': scale_type}}

@@ -371,7 +373,8 @@ def quant_format_from_string(quant_format):
**weight_float_format)

if scale_rounding_func_type is not None:
scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte}
scale_rounding_func_dict = {
'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte, 'midmax': RoundMidMaxSte}
scale_type = scale_rounding_func_dict[scale_rounding_func_type]
weight_quant = weight_quant.let(**{'restrict_value_float_to_int_impl': scale_type})

2 changes: 1 addition & 1 deletion src/brevitas_examples/llm/llm_args.py
@@ -113,7 +113,7 @@ def create_args_parser() -> ArgumentParser:
'--scale-rounding-func-type',
type=str,
default=None,
choices=['round', 'ceil', 'floor'],
choices=['round', 'ceil', 'floor', 'midmax'],
help='Rounding function to use with Po2 scale. Default: None.')
parser.add_argument(
'--weight-group-dim',
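Taken together with the dictionary change above, the new choice wires the CLI flag through to the injector. A hypothetical end-to-end trace (the entry-point module name is assumed; only --scale-rounding-func-type and the 'midmax' entry come from this diff):

# python -m brevitas_examples.llm.main --scale-rounding-func-type midmax
from brevitas.core.function_wrapper import CeilSte
from brevitas.core.function_wrapper import FloorSte
from brevitas.core.restrict_val import RoundSte
from brevitas.core.scaling import RoundMidMaxSte

# 'midmax' resolves next to the existing choices and is injected as the
# Po2 scale rounding implementation of the weight/input quantizers
scale_rounding_func_dict = {'ceil': CeilSte, 'floor': FloorSte, 'round': RoundSte, 'midmax': RoundMidMaxSte}
scale_type = scale_rounding_func_dict['midmax']
input_kwargs = {'restrict_value_float_to_int_impl': scale_type}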
41 changes: 33 additions & 8 deletions tests/brevitas/core/test_quant_mx.py
@@ -11,10 +11,12 @@
from typing import Tuple
from typing import Union

from hypothesis import example
from hypothesis import given
import pytest_cases
import torch

from brevitas.core.scaling import RoundMidMaxSte
from brevitas.nn.quant_activation import QuantIdentity
from brevitas.nn.quant_linear import QuantLinear
from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
@@ -139,38 +141,60 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):


MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)}
SCALE_ROUNDING = ["floor", "midmax"]
RANGE = 1e9

edge_cases = [
torch.tensor([[7.5 if i == 0 else 1.0 for i in range(32)]]), # MXFP4, MidMax
torch.tensor([[3.875 if i == 0 else 1.0 for i in range(32)]]), # MXFP8E4M3, MidMax
]

@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))

@example(inp=edge_cases[0])
@example(inp=edge_cases[1])
@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
@pytest_cases.parametrize('bit_widths', list(MAP.keys()))
def test_act_mx(inp, bit_widths):
@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
def test_act_mx(inp, bit_widths, scale_rounding):
torch.set_printoptions(precision=12, sci_mode=False)
exp, mant = MAP[bit_widths]

extra_kwargs = {}
# Default rounding is 'floor', so only pass an override when testing 'midmax'
if scale_rounding == "midmax":
extra_kwargs["restrict_value_float_to_int_impl"] = RoundMidMaxSte
act_quant = QuantIdentity(
MXFloat8e4m3Act,
exponent_bit_width=exp,
mantissa_bit_width=mant,
bit_width=mant + exp + 1,
group_dim=1,
return_quant_tensor=True)
return_quant_tensor=True,
**extra_kwargs)
act_quant.eval()
x = inp

quantizer = MXFP(bit_widths)

qx = act_quant(x)

y = quantizer.quantize(x)
y = quantizer.quantize(x, select=scale_rounding == "midmax")
assert torch.allclose(qx.value, y, atol=1e-8)


@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
@example(inp=edge_cases[0])
@example(inp=edge_cases[1])
@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
@pytest_cases.parametrize('bit_widths', list(MAP.keys()))
@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
@pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
def test_weight_mx(inp, bit_widths, weight_quant_type):
def test_weight_mx(inp, bit_widths, scale_rounding, weight_quant_type):
torch.set_printoptions(precision=12, sci_mode=False)
exp, mant = MAP[bit_widths]
extra_kwargs = {}
# Default rounding is 'floor', so only pass an override when testing 'midmax'
if scale_rounding == "midmax":
extra_kwargs["weight_restrict_value_float_to_int_impl"] = RoundMidMaxSte
weight_quant = QuantLinear(
32,
1,
@@ -179,7 +203,8 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
weight_scaling_impl_type=weight_quant_type,
weight_exponent_bit_width=exp,
weight_mantissa_bit_width=mant,
weight_bit_width=mant + exp + 1)
weight_bit_width=mant + exp + 1,
**extra_kwargs)

x = inp
weight_quant.weight.data = x
@@ -189,6 +214,6 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
qx_weight = weight_quant.quant_weight()
qx_weight_two = weight_quant.quant_weight()

y = quantizer.quantize(x)
y = quantizer.quantize(x, select=scale_rounding == "midmax")
assert torch.allclose(qx_weight.value, y, atol=1e-8)
assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
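The edge-case inputs above are chosen to sit just past the midmax midpoint, where floor and midmax rounding pick different shared exponents. A standalone trace, assuming the restrict impl is applied to the log2 of the group's max-abs threshold:

import torch

from brevitas.core.scaling.float_scaling import _calculate_midmax_bias

# MXFP4 (e2m1): midpoint mantissa is 1.75, and 7.5 = 1.875 * 4 lies above it
x = torch.log2(torch.tensor(7.5))
bias = _calculate_midmax_bias(torch.tensor(1.0), 0.0)
print(torch.floor(x).item(), torch.ceil(x - bias).item())  # floor: 2.0, midmax: 3.0

# MXFP8 E4M3 (one NaN code, bias 1.0): midpoint mantissa is 1.875; 3.875 = 1.9375 * 2
x = torch.log2(torch.tensor(3.875))
bias = _calculate_midmax_bias(torch.tensor(3.0), 1.0)
print(torch.floor(x).item(), torch.ceil(x - bias).item())  # floor: 1.0, midmax: 2.0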