11import torch
22from lightllm .common .basemodel .layer_weights .meta_weights .mm_weight .mm_weight import (
3- SingleMMWeightTpl ,
3+ MMWeightTpl ,
44 DeepGemmFP8W8A8B128MMWeight ,
55 AWQMMWeightTpl ,
66)
77from lightllm .common .quantization import Quantcfg
88from lightllm .utils .dist_utils import get_current_device_id
99from lightllm .common .quantization .quantize_method import QuantizationMethod
10- from typing import Dict , List , Optional
10+ from typing import Dict , List , Optional , Union
1111from .mm_slicer import ColSliceMixin , QuantizedRowSliceMixin , QuantizedColSliceMixin
1212
1313
class UnquantizedCOLMMWeight(MMWeightTpl):
    """Column-sliced, unquantized matmul weight.

    Post-patch, this accepts either a single weight name or a list of names
    (``Union[str, List[str]]``) and delegates all loading/storage to the
    ``MMWeightTpl`` base class.

    Args:
        weight_names: One weight tensor name, or a list of names to fuse.
        data_type: Target dtype for the loaded weight.
        bias_names: Optional bias name(s), mirroring ``weight_names``.
        quant_method: Quantization method; ``None`` for the unquantized path.
        tp_rank: Tensor-parallel rank of this process.
        tp_world_size: Tensor-parallel world size.
    """

    def __init__(
        self,
        weight_names: Union[str, List[str]],
        data_type: torch.dtype,
        bias_names: Optional[Union[str, List[str]]] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        # NOTE(review): this class forwards via the *singular* keywords
        # (weight_name=/bias_name=) while the sibling classes below use the
        # plural forms — confirm against MMWeightTpl's signature.
        super().__init__(
            weight_name=weight_names,
            data_type=data_type,
            bias_name=bias_names,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
        # NOTE(review): the original hunk continues past this view; any
        # trailing statements of __init__ (e.g. slicer setup using
        # ColSliceMixin) are not visible here — verify against the full file.
class DeepGemmFP8W8A8B128COLMMWeight(DeepGemmFP8W8A8B128MMWeight):
    """Column-sliced FP8 (w8a8, block-128) matmul weight for DeepGEMM.

    Thin constructor that forwards the (possibly multiple) weight/bias names
    to ``DeepGemmFP8W8A8B128MMWeight`` using the plural keyword names.

    Args:
        weight_names: One weight tensor name, or a list of names to fuse.
        data_type: Target dtype for the loaded weight.
        bias_names: Optional bias name(s), mirroring ``weight_names``.
        quant_method: Quantization method implementation.
        tp_rank: Tensor-parallel rank of this process.
        tp_world_size: Tensor-parallel world size.
    """

    def __init__(
        self,
        weight_names: Union[str, List[str]],
        data_type: torch.dtype,
        bias_names: Optional[Union[str, List[str]]] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(
            weight_names=weight_names,
            data_type=data_type,
            bias_names=bias_names,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
        # NOTE(review): the original hunk continues past this view; any
        # trailing __init__ statements (e.g. a quantized column slicer) are
        # not visible here — verify against the full file.
class AWQCOLMMWeight(AWQMMWeightTpl):
    """Column-sliced AWQ-quantized matmul weight.

    Thin constructor that forwards the (possibly multiple) weight/bias names
    to ``AWQMMWeightTpl`` using the plural keyword names.

    Args:
        weight_names: One weight tensor name, or a list of names to fuse.
        data_type: Target dtype for the loaded weight.
        bias_names: Optional bias name(s), mirroring ``weight_names``.
        quant_method: AWQ quantization method implementation.
        tp_rank: Tensor-parallel rank of this process.
        tp_world_size: Tensor-parallel world size.
    """

    def __init__(
        self,
        weight_names: Union[str, List[str]],
        data_type: torch.dtype,
        bias_names: Optional[Union[str, List[str]]] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        super().__init__(
            weight_names=weight_names,
            data_type=data_type,
            bias_names=bias_names,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
        # NOTE(review): the original hunk continues past this view; any
        # trailing __init__ statements (e.g. a quantized row/col slicer) are
        # not visible here — verify against the full file.
class AWQMARLINCOLMMWeight(AWQCOLMMWeight):
    """Column-sliced AWQ weight using the Marlin kernel path.

    This patch removed the Marlin-specific ``_process_weight``,
    ``_process_weight_scale`` and ``_process_weight_zero_point`` overrides
    (which moved tensors to the current CUDA device and ran the quant
    method's post-loading processing); the class is now a pure pass-through
    constructor over ``AWQCOLMMWeight``.

    Args:
        weight_names: One weight tensor name, or a list of names to fuse.
        data_type: Target dtype for the loaded weight.
        bias_names: Optional bias name(s), mirroring ``weight_names``.
        quant_method: AWQ/Marlin quantization method implementation.
        tp_rank: Tensor-parallel rank of this process.
        tp_world_size: Tensor-parallel world size.
    """

    def __init__(
        self,
        weight_names: Union[str, List[str]],
        data_type: torch.dtype,
        bias_names: Optional[Union[str, List[str]]] = None,
        quant_method: QuantizationMethod = None,
        tp_rank: int = None,
        tp_world_size: int = None,
    ) -> None:
        # NOTE(review): with the _process_weight* overrides deleted, the
        # get_current_device_id import at the top of this file may now be
        # unused — confirm before removing.
        super().__init__(
            weight_names=weight_names,
            data_type=data_type,
            bias_names=bias_names,
            quant_method=quant_method,
            tp_rank=tp_rank,
            tp_world_size=tp_world_size,
        )
11798COLMM_WEIGHT_CLS_MAP = {
11899 "deepgemm-fp8w8a8-b128" : DeepGemmFP8W8A8B128COLMMWeight ,