2626
2727
2828@register_dtype (("block_fp8_sym" , "block_fp8" , "block_fp8_e4m3" ))
29- def quant_block_fp_sym (tensor , max_scale = 1.0 , tensor_max = None , group_size = (128 , 128 ), v = 0 , ** kwargs ):
29+ def quant_block_fp_sym (tensor , max_scale = 1.0 , tensor_max = None , group_size = (128 , 128 ), v = 0 , tensor_min = None , ** kwargs ):
3030 """Symmetric quantization using block float8 format.
3131
3232 Args:
@@ -51,9 +51,18 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
5151 if tensor_max is None :
5252 max_tensor = tensor .abs ().amax (dim = (- 2 , - 1 )) * max_scale
5353 elif isinstance (tensor_max , torch .Tensor ):
54- max_tensor = tensor_max .to (tensor .device ) * max_scale
54+ max_tensor = (
55+ tensor_max .to (tensor .device ) * max_scale
56+ if tensor_min is None
57+ else torch .maximum (tensor_max .abs (), tensor_min .abs ()).to (tensor .device ) * max_scale
58+ )
5559 else :
56- max_tensor = torch .tensor (tensor_max ).to (tensor .device ) * max_scale
60+ max_tensor = (
61+ torch .tensor (tensor_max ).to (tensor .device ) * max_scale
62+ if tensor_min is None
63+ else torch .maximum (torch .tensor (tensor_max ).abs (), torch .tensor (tensor_min ).abs ()).to (tensor .device )
64+ * max_scale
65+ )
5766 scale = max_tensor / info .max
5867 assert len (scale .shape ) == 2 , f"Only support 2D group_size, but get { len (scale .shape )} "
5968 min_scaling_factor = float (1.0 / (info .max * 512.0 )) ##copy from vllm
@@ -71,7 +80,7 @@ def quant_block_fp_sym(tensor, max_scale=1.0, tensor_max=None, group_size=(128,
7180
7281
7382@register_dtype (("fp8_sym" , "fp8" , "fp8_e4m3" ))
74- def quant_fp8_sym (tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , ** kwargs ):
83+ def quant_fp8_sym (tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , tensor_min = None , ** kwargs ):
7584 """Symmetric quantization using float8 format.
7685
7786 Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -98,9 +107,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
98107 if tensor_max is None : ##dynamic per-token
99108 max_tensor = torch .max (torch .abs (tensor ), dim = - 1 )[0 ] * max_scale
100109 elif isinstance (tensor_max , torch .Tensor ):
101- max_tensor = tensor_max .to (tensor .device ) * max_scale
110+ max_tensor = (
111+ tensor_max .to (tensor .device ) * max_scale
112+ if tensor_min is None
113+ else torch .maximum (tensor_max .abs (), tensor_min .abs ()).to (tensor .device ) * max_scale
114+ )
102115 else :
103- max_tensor = torch .tensor (tensor_max ).to (tensor .device ) * max_scale
116+ max_tensor = (
117+ torch .tensor (tensor_max ).to (tensor .device ) * max_scale
118+ if tensor_min is None
119+ else torch .maximum (torch .tensor (tensor_max ).abs (), torch .tensor (tensor_min ).abs ()).to (tensor .device )
120+ * max_scale
121+ )
104122 scale = max_tensor .to (torch .float32 ) / info .max
105123 min_scaling_factor = float (1.0 / (info .max * 512.0 )) ##copy from vllm
106124 scale = torch .clip (scale , min = min_scaling_factor )
@@ -117,7 +135,7 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
117135
118136
119137@register_dtype ("fp8_e5m2" )
120- def quant_fp8_e5m2 (tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , ** kwargs ):
138+ def quant_fp8_e5m2 (tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , tensor_min = None , ** kwargs ):
121139 """Symmetric quantization using float8 format.
122140
123141 Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -140,9 +158,18 @@ def quant_fp8_e5m2(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, *
140158 if tensor_max is None : ##dynamic per-token
141159 max_tensor = torch .max (torch .abs (tensor ), dim = - 1 )[0 ] * max_scale
142160 elif isinstance (tensor_max , torch .Tensor ):
143- max_tensor = tensor_max .to (tensor .device ) * max_scale
161+ max_tensor = (
162+ tensor_max .to (tensor .device ) * max_scale
163+ if tensor_min is None
164+ else torch .maximum (tensor_max .abs (), tensor_min .abs ()).to (tensor .device ) * max_scale
165+ )
144166 else :
145- max_tensor = torch .tensor (tensor_max ).to (tensor .device ) * max_scale
167+ max_tensor = (
168+ torch .tensor (tensor_max ).to (tensor .device ) * max_scale
169+ if tensor_min is None
170+ else torch .maximum (torch .tensor (tensor_max ).abs (), torch .tensor (tensor_min ).abs ()).to (tensor .device )
171+ * max_scale
172+ )
146173 scale = max_tensor .to (torch .float32 ) / info .max
147174 min_scaling_factor = float (1.0 / (info .max * 512.0 )) ##copy from vllm
148175 scale = torch .clip (scale , min = min_scaling_factor )
@@ -225,7 +252,7 @@ def quant_fp8_e5m2_unit_scale(tensor, max_scale=1.0, tensor_max=None, group_size
225252
226253
227254@register_dtype ("fp8_gaudi3_sym" )
228- def quant_fp8_sym_gaudi3 (tensor , max_scale = 1.0 , tensor_max = None , ** kwargs ):
255+ def quant_fp8_sym_gaudi3 (tensor , max_scale = 1.0 , tensor_max = None , tensor_min = None , ** kwargs ):
229256 """Symmetric quantization using float8 format.
230257
231258 Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -250,9 +277,19 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
250277 tensor = tensor .reshape (- 1 , orig_shape [- 1 ])
251278 max_tensor = torch .max (torch .abs (tensor ), dim = - 1 )[0 ] * max_scale
252279 elif isinstance (tensor_max , torch .Tensor ):
253- max_tensor = tensor_max .clone ().detach ().to (tensor .device ) * max_scale
280+ max_tensor = (
281+ tensor_max .to (tensor .device ) * max_scale
282+ if tensor_min is None
283+ else torch .maximum (tensor_max .clone ().detach ().abs (), tensor_min .clone ().detach ().abs ()).to (tensor .device )
284+ * max_scale
285+ )
254286 else :
255- max_tensor = torch .tensor (tensor_max ).to (tensor .device ) * max_scale
287+ max_tensor = (
288+ torch .tensor (tensor_max ).to (tensor .device ) * max_scale
289+ if tensor_min is None
290+ else torch .maximum (torch .tensor (tensor_max ).abs (), torch .tensor (tensor_min ).abs ()).to (tensor .device )
291+ * max_scale
292+ )
256293 scale = max_tensor .to (torch .float32 ) / fp8_max
257294 min_scaling_factor = float (1.0 / (fp8_max * 512.0 )) ##copy from vllm
258295 scale = torch .clip (scale , min = min_scaling_factor )
@@ -271,7 +308,9 @@ def quant_fp8_sym_gaudi3(tensor, max_scale=1.0, tensor_max=None, **kwargs):
271308if is_gaudi2 ():
272309
273310 @register_dtype (("fp8_sym" , "fp8" , "fp8_e4m3" ))
274- def quant_fp8_sym (tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , ** kwargs ): # pylint: disable=E0102
311+ def quant_fp8_sym (
312+ tensor , max_scale = 1.0 , tensor_max = None , group_size = - 1 , v = 0 , tensor_min = None , ** kwargs
313+ ): # pylint: disable=E0102
275314 """Symmetric quantization using float8 format.
276315
277316 Allows both dynamic per-token scaling and tensor-wide quantization depending on input.
@@ -300,9 +339,18 @@ def quant_fp8_sym(tensor, max_scale=1.0, tensor_max=None, group_size=-1, v=0, **
300339 if tensor_max is None : ##dynamic per-token
301340 max_tensor = torch .max (torch .abs (tensor ), dim = - 1 )[0 ] * max_scale
302341 elif isinstance (tensor_max , torch .Tensor ):
303- max_tensor = tensor_max .to (tensor .device ) * max_scale
342+ max_tensor = (
343+ tensor_max .to (tensor .device ) * max_scale
344+ if tensor_min is None
345+ else torch .maximum (tensor_max .abs (), tensor_min .abs ()).to (tensor .device ) * max_scale
346+ )
304347 else :
305- max_tensor = torch .tensor (tensor_max ).to (tensor .device ) * max_scale
348+ max_tensor = (
349+ torch .tensor (tensor_max ).to (tensor .device ) * max_scale
350+ if tensor_min is None
351+ else torch .maximum (torch .tensor (tensor_max ).abs (), torch .tensor (tensor_min ).abs ()).to (tensor .device )
352+ * max_scale
353+ )
306354 scale = max_tensor .to (torch .float32 ) / info .max
307355 min_scaling_factor = float (1.0 / (info .max * 512.0 )) ##copy from vllm
308356 scale = torch .clip (scale , min = min_scaling_factor )
0 commit comments