@@ -37,7 +37,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
-        - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
         - weight_loader: a callable or method responsible for loading the weight data
         """
@@ -51,9 +50,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             layer.weight,
             {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
         )
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            split_axis = extra_weight_attrs.get("split_axis")
-            _set_var_distributed(layer.weight, split_axis=split_axis)
+        if hasattr(layer, "nranks") and layer.nranks > 1:
             set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})

     def process_loaded_weights(self, layer, weights) -> None:
@@ -125,6 +122,10 @@ def __init__(
         # key
         if weight_key:
             self.weight_key = f"{prefix}.{weight_key}"
+        elif fd_config.model_config.is_quantized and not skip_quant:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
         else:
             self.weight_key = f"{prefix}.weight"
         self.bias_key = f"{prefix}.bias"
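The key naming added above can be summarized with a small helper that derives the expected checkpoint entry names from a layer prefix. The function checkpoint_keys below is hypothetical and only mirrors the rules visible in this hunk; it is not part of the PR.

# Hypothetical helper mirroring the key-naming rules added in __init__ above.
def checkpoint_keys(prefix: str, is_quantized: bool, skip_quant: bool, weight_key: str = "") -> dict:
    if weight_key:
        return {"weight": f"{prefix}.{weight_key}"}
    if is_quantized and not skip_quant:
        return {
            "weight": f"{prefix}.quant_weight",
            "weight_scale": f"{prefix}.weight_scale",
            "act_scale": f"{prefix}.activation_scale",
        }
    return {"weight": f"{prefix}.weight"}

print(checkpoint_keys("model.layers.0.mlp.up_proj", is_quantized=True, skip_quant=False))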
@@ -173,7 +174,11 @@ def load_prequant_weight(self, state_dict: dict):
         Args:
             state_dict (dict): A dictionary containing the prequantized weights and scales.
         """
-        self.quant_method.process_prequanted_weights(self, state_dict)
+        if isinstance(self.quant_method, UnquantizedLinearMethod):
+            # for gate
+            self.load_weight(state_dict)
+        else:
+            self.quant_method.process_prequanted_weights(self, state_dict)

     def load_weight(self, state_dict: dict):
         """
@@ -333,18 +338,18 @@ def __init__(
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            split_axis=1,
             output_dim=True,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-
-        if self.with_bias:
-            if self.nranks > 0:
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=1)
+            if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})
+                if self.nranks > 1:
+                    set_weight_attrs(self.bias, {"output_dim": True})


 class MergedColumnParallelLinear(ColumnParallelLinear):
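One way to read the reshuffled column-parallel branch above: marking the weight and bias as distributed with split_axis=1 applies whenever tensor-parallel metadata exists (nranks > 0), while the output_dim shard attribute only matters when the tensor is actually split across more than one rank (nranks > 1). A framework-free sketch of that two-level guard follows; _mark_distributed and _mark_sharded are stand-ins for _set_var_distributed and set_weight_attrs, and the dict-based "layer" is purely illustrative.

# Illustrative only; not the FastDeploy implementation.
def _mark_distributed(layer, name, split_axis):
    layer.setdefault("dist", {})[name] = split_axis

def _mark_sharded(layer, name, output_dim):
    layer.setdefault("attrs", {})[name] = {"output_dim": output_dim}

def configure_column_parallel(layer, nranks: int, with_bias: bool):
    if nranks > 0:
        _mark_distributed(layer, "weight", split_axis=1)   # column split
        if with_bias:
            _mark_distributed(layer, "bias", split_axis=1)
            if nranks > 1:
                _mark_sharded(layer, "bias", output_dim=True)

layer = {}
configure_column_parallel(layer, nranks=2, with_bias=True)
print(layer)  # {'dist': {'weight': 1, 'bias': 1}, 'attrs': {'bias': {'output_dim': True}}}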
@@ -669,15 +674,19 @@ def __init__(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=0)
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                if self.nranks > 1:
+                    set_weight_attrs(
+                        self.bias,
+                        {
+                            "output_dim": False,
+                        },
+                    )

-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
         self.reduce_results = reduce_results

     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
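For contrast with the column-parallel case, the row-parallel branch above uses split_axis=0: the weight is split along its input dimension and the partial outputs are reduced afterwards when reduce_results is enabled. The toy split below only illustrates which dimension each scheme shards; it uses plain lists and assumes nothing about the real layer classes or weight layout.

# Toy illustration of column vs row parallel splits on a 4x4 "weight" (list of rows).
W = [[r * 4 + c for c in range(4)] for r in range(4)]

def split_columns(w, nranks):      # column parallel: split_axis=1, shard the output dim
    cols = len(w[0]) // nranks
    return [[row[i * cols:(i + 1) * cols] for row in w] for i in range(nranks)]

def split_rows(w, nranks):         # row parallel: split_axis=0, shard the input dim
    rows = len(w) // nranks
    return [w[i * rows:(i + 1) * rows] for i in range(nranks)]

print(split_columns(W, 2)[0])  # each rank holds a 4x2 shard
print(split_rows(W, 2)[0])     # each rank holds a 2x4 shard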