 from typing import Tuple
 from typing import Union
 
+from hypothesis import example
 from hypothesis import given
 import pytest_cases
 import torch
 
+from brevitas.core.scaling import RoundMidMaxSte
 from brevitas.nn.quant_activation import QuantIdentity
 from brevitas.nn.quant_linear import QuantLinear
 from brevitas.quant.experimental.mx_quant_ocp import MXFloat8e4m3Act
@@ -139,38 +141,60 @@ def quantize(self, tensor: torch.Tensor, axis: int = -1, select: bool = False):
 
 
 MAP = {"e4m3": (4, 3), "e5m2": (5, 2), "e2m3": (2, 3), "e3m2": (3, 2), "e2m1": (2, 1)}
+SCALE_ROUNDING = ["floor", "midmax"]
+RANGE = 1e9
 
+edge_cases = [
+    torch.tensor([[7.5 if i == 0 else 1.0 for i in range(32)]]),  # MXFP4, MidMax
+    torch.tensor([[3.875 if i == 0 else 1.0 for i in range(32)]]),  # MXFP8E4M3, MidMax
+]
 
-@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
+
+@example(inp=edge_cases[0])
+@example(inp=edge_cases[1])
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
-def test_act_mx(inp, bit_widths):
+@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
+def test_act_mx(inp, bit_widths, scale_rounding):
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
 
+    extra_kwargs = {}
+    # Default rounding should be 'floor', so only populate this dict when overriding the default
+    if scale_rounding == "midmax":
+        extra_kwargs["restrict_value_float_to_int_impl"] = RoundMidMaxSte
     act_quant = QuantIdentity(
         MXFloat8e4m3Act,
         exponent_bit_width=exp,
         mantissa_bit_width=mant,
         bit_width=mant + exp + 1,
         group_dim=1,
-        return_quant_tensor=True)
+        return_quant_tensor=True,
+        **extra_kwargs)
     act_quant.eval()
     x = inp
 
     quantizer = MXFP(bit_widths)
 
     qx = act_quant(x)
 
-    y = quantizer.quantize(x)
+    y = quantizer.quantize(x, select=scale_rounding == "midmax")
     assert torch.allclose(qx.value, y, atol=1e-8)
 
 
-@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=1e10, min_val=-1e10))
+@example(inp=edge_cases[0])
+@example(inp=edge_cases[1])
+@given(inp=float_tensor_nz_st(shape=(1, 32), max_val=RANGE, min_val=-RANGE))
 @pytest_cases.parametrize('bit_widths', list(MAP.keys()))
+@pytest_cases.parametrize('scale_rounding', SCALE_ROUNDING)
 @pytest_cases.parametrize('weight_quant_type', ['stats', 'parameter_from_stats'])
-def test_weight_mx(inp, bit_widths, weight_quant_type):
+def test_weight_mx(inp, bit_widths, scale_rounding, weight_quant_type):
     torch.set_printoptions(precision=12, sci_mode=False)
     exp, mant = MAP[bit_widths]
+    extra_kwargs = {}
+    # Default rounding should be 'floor', so only populate this dict when overriding the default
+    if scale_rounding == "midmax":
+        extra_kwargs["weight_restrict_value_float_to_int_impl"] = RoundMidMaxSte
     weight_quant = QuantLinear(
         32,
         1,
@@ -179,7 +203,8 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
         weight_scaling_impl_type=weight_quant_type,
         weight_exponent_bit_width=exp,
         weight_mantissa_bit_width=mant,
-        weight_bit_width=mant + exp + 1)
+        weight_bit_width=mant + exp + 1,
+        **extra_kwargs)
 
     x = inp
     weight_quant.weight.data = x
@@ -189,6 +214,6 @@ def test_weight_mx(inp, bit_widths, weight_quant_type):
     qx_weight = weight_quant.quant_weight()
     qx_weight_two = weight_quant.quant_weight()
 
-    y = quantizer.quantize(x)
+    y = quantizer.quantize(x, select=scale_rounding == "midmax")
     assert torch.allclose(qx_weight.value, y, atol=1e-8)
     assert torch.allclose(qx_weight_two.value, y, atol=1e-8)
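For context on what the new scale_rounding parametrization toggles: an MX block of 32 elements shares one power-of-two scale derived from the block's absolute maximum. The default "floor" rounding takes floor(log2(amax)) minus the element format's maximum exponent (the OCP MX convention); "midmax" is the alternative that the new edge cases target, via RoundMidMaxSte on the Brevitas side and quantize(..., select=True) in the reference quantizer. The sketch below illustrates only the floor rule under those assumptions; the helper name mx_shared_scale_floor is hypothetical and the exact midmax threshold is deliberately not reproduced.

import torch

def mx_shared_scale_floor(block: torch.Tensor, emax_elem: int) -> torch.Tensor:
    # emax_elem is the exponent of the element format's largest normal value,
    # e.g. 2 for e2m1 (max 6.0) or 8 for e4m3 (max 448.0).
    amax = block.abs().amax()
    shared_exp = torch.floor(torch.log2(amax)) - emax_elem
    return torch.exp2(shared_exp)

# The MXFP4 edge case above (amax = 7.5, emax_elem = 2) gives floor(log2(7.5)) - 2 = 0,
# i.e. a scale of 1.0, so 7.5 saturates to the e2m1 maximum of 6.0; it sits on the
# boundary where a midmax-style rounding is expected to pick the next larger scale.
print(mx_shared_scale_floor(torch.tensor([7.5] + [1.0] * 31), emax_elem=2))  # tensor(1.)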