
Commit 9408e66

[bugfix]fix blockwisefp8 and all_reduce (#3243)
* fix
* update
* fix linear for prequant loader
1 parent: 3a15e0c

4 files changed: +37, -24 lines


fastdeploy/model_executor/layers/embeddings.py

Lines changed: 4 additions & 2 deletions
@@ -81,7 +81,8 @@ def __init__(
                     initializer=nn.initializer.Normal(mean=0.0, std=self.initializer_range),
                 ),
             )
-            set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": False})
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
@@ -91,7 +92,8 @@ def __init__(
 
             self.embeddings.weight.is_distributed = True
             self.embeddings.weight.split_axis = 1
-            set_weight_attrs(self.embeddings.weight, {"output_dim": True})
+            if self.world_size > 1:
+                set_weight_attrs(self.embeddings.weight, {"output_dim": True})
 
         self.prefix = prefix
         self.dropout = nn.Dropout(self.hidden_dropout_prob)
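For context, output_dim is metadata consumed by the weight loader: it tells the loader along which dimension a checkpoint tensor should be sharded across tensor-parallel ranks. On a single rank the full tensor should be loaded unchanged, which is presumably why the attribute is now attached only when world_size > 1. The sketch below is a minimal, hypothetical loader illustrating that idea; FakeParam, load_weight_shard, and the axis convention are assumptions, not FastDeploy code.

    import numpy as np

    class FakeParam:
        """Toy stand-in for a Paddle parameter carrying extra attributes."""
        output_dim = None  # would be set via set_weight_attrs in the real code

    def load_weight_shard(full_weight, param, rank, world_size):
        # Single rank (or no attribute): keep the whole checkpoint tensor.
        if world_size == 1 or param.output_dim is None:
            return full_weight
        # Otherwise shard along the output axis or the input axis (assumed convention).
        axis = -1 if param.output_dim else 0
        return np.array_split(full_weight, world_size, axis=axis)[rank]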

fastdeploy/model_executor/layers/linear.py

Lines changed: 27 additions & 18 deletions
@@ -37,7 +37,6 @@ class UnquantizedLinearMethod(QuantMethodBase):
     def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
         """
         extra_weight_attrs is a dictionary that may include parameters like:
-        - split_axis: specifies which axis to split the weight tensor on (for distributed weight partitioning)
         - output_dim: determines whether the split is applied along the output dimension (rows) or input dimension (columns)
         - weight_loader: a callable or method responsible for loading the weight data
         """
@@ -51,9 +50,7 @@ def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
             layer.weight,
             {"weight_loader": extra_weight_attrs.get("weight_loader", default_weight_loader(layer.fd_config))},
         )
-        if hasattr(layer, "nranks") and layer.nranks > 0:
-            split_axis = extra_weight_attrs.get("split_axis")
-            _set_var_distributed(layer.weight, split_axis=split_axis)
+        if hasattr(layer, "nranks") and layer.nranks > 1:
             set_weight_attrs(layer.weight, {"output_dim": extra_weight_attrs.get("output_dim")})
 
     def process_loaded_weights(self, layer, weights) -> None:
@@ -125,6 +122,10 @@ def __init__(
         # key
         if weight_key:
             self.weight_key = f"{prefix}.{weight_key}"
+        elif fd_config.model_config.is_quantized and not skip_quant:
+            self.weight_key = f"{prefix}.quant_weight"
+            self.weight_scale_key = f"{prefix}.weight_scale"
+            self.act_scale_key = f"{prefix}.activation_scale"
         else:
             self.weight_key = f"{prefix}.weight"
         self.bias_key = f"{prefix}.bias"
@@ -173,7 +174,11 @@ def load_prequant_weight(self, state_dict: dict):
         Args:
             state_dict (dict): A dictionary containing the prequantized weights and scales.
         """
-        self.quant_method.process_prequanted_weights(self, state_dict)
+        if isinstance(self.quant_method, UnquantizedLinearMethod):
+            # for gate
+            self.load_weight(state_dict)
+        else:
+            self.quant_method.process_prequanted_weights(self, state_dict)
 
     def load_weight(self, state_dict: dict):
         """
@@ -333,18 +338,18 @@ def __init__(
         assert self.quant_method is not None
         self.quant_method.create_weights(
             self,
-            split_axis=1,
             output_dim=True,
             weight_loader=(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
-
-        if self.with_bias:
-            if self.nranks > 0:
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=1)
+            if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=1)
-                set_weight_attrs(self.bias, {"output_dim": True})
+                if self.nranks > 1:
+                    set_weight_attrs(self.bias, {"output_dim": True})
 
 
 class MergedColumnParallelLinear(ColumnParallelLinear):
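The restructured block always marks the weight as distributed (split_axis=1, the output dimension of an [in_features, out_features] weight) when a parallel group exists, and only attaches the loader-facing output_dim attribute when nranks > 1. Judging from the manual flags set in embeddings.py above (weight.is_distributed = True; weight.split_axis = 1), _set_var_distributed likely records similar metadata; the snippet below imitates that assumption rather than the real helper:

    def mark_distributed(var, split_axis):
        # Assumed behaviour, modelled on the manual flags in embeddings.py:
        # remember that the tensor is sharded and along which axis.
        var.is_distributed = True
        var.split_axis = split_axis

    class _Var:
        pass

    w = _Var()
    mark_distributed(w, split_axis=1)   # column parallel: shard the output dimension
    assert (w.is_distributed, w.split_axis) == (True, 1)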
@@ -669,15 +674,19 @@ def __init__(
                 self.weight_loader if hasattr(self, "weight_loader") else default_weight_loader(self.fd_config)
             ),
         )
+        if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=0)
+            if self.with_bias:
+                # col parallel
+                _set_var_distributed(self.bias, split_axis=0)
+                if self.nranks > 1:
+                    set_weight_attrs(
+                        self.bias,
+                        {
+                            "output_dim": False,
+                        },
+                    )
 
-        if self.with_bias:
-            _set_var_distributed(self.bias, split_axis=0)
-            set_weight_attrs(
-                self.bias,
-                {
-                    "output_dim": False,
-                },
-            )
         self.reduce_results = reduce_results
 
     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
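The row-parallel case splits the weight along the input dimension (split_axis=0), so each rank only sees part of every dot product and the partial outputs have to be summed across ranks; that cross-rank sum is the all_reduce the commit title refers to, presumably toggled by reduce_results. A NumPy sketch of the arithmetic, with made-up sizes:

    import numpy as np

    in_features, out_features, nranks = 16, 8, 4   # made-up sizes
    x = np.random.randn(2, in_features)            # a batch of activations
    w_full = np.random.randn(in_features, out_features)

    # Row parallel: weight rows (input dim) and the matching input columns are sharded.
    x_shards = np.split(x, nranks, axis=1)
    w_shards = np.split(w_full, nranks, axis=0)
    partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]

    # Each rank holds a partial (2, out_features) result; summing them is the all_reduce.
    assert np.allclose(sum(partials), x @ w_full)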

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 5 additions & 2 deletions
@@ -60,6 +60,7 @@ def __init__(
         self.bias_key: Optional[str] = None
         self.use_ep: bool = fd_config.parallel_config.use_ep
         self.column_cut = True
+        self.nranks = fd_config.parallel_config.tensor_parallel_size
 
         ColumnParallelLinear = fleet.meta_parallel.ColumnParallelLinear
         RowParallelLinear = fleet.meta_parallel.RowParallelLinear
@@ -91,7 +92,8 @@ def __init__(
                 gather_output=need_gather,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": True})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
             self.linear = RowParallelLinear(
                 embedding_dim,
@@ -102,7 +104,8 @@ def __init__(
                 input_is_parallel=False,
                 fuse_matmul_bias=False,  # False gives a smaller numerical diff
             )
-            set_weight_attrs(self.linear.weight, {"output_dim": False})
+            if self.nranks > 1:
+                set_weight_attrs(self.linear.weight, {"output_dim": False})
 
     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """

fastdeploy/model_executor/layers/quantization/block_wise_fp8.py

Lines changed: 1 addition & 2 deletions
@@ -83,7 +83,7 @@ def __init__(
 
     def create_weights(self, layer, **extra_weight_attrs):
         layer.weight_shape.reverse()
-
+        layer.weight_dtype = "float8_e4m3fn"
         layer.weight = layer.create_parameter(
             shape=layer.weight_shape,
             dtype=layer.weight_dtype,
@@ -101,7 +101,6 @@ def create_weights(self, layer, **extra_weight_attrs):
             dtype="float32",
             is_bias=False,
         )
-        layer.weight_dtype = "float8_e4m3fn"
 
     def process_loaded_weights(self, layer, weights) -> None:
         weight_tensor = weights.transpose([1, 0])
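The substance of this file's fix is ordering: layer.create_parameter(..., dtype=layer.weight_dtype) reads the dtype at call time, so assigning layer.weight_dtype = "float8_e4m3fn" only after the parameter exists leaves the weight in whatever dtype was set before. Moving the assignment above create_parameter makes the block-wise FP8 weight actually get allocated as float8_e4m3fn. A minimal stand-in (no Paddle, and the default dtype is an assumption) showing the same pitfall:

    class ToyLayer:
        weight_dtype = "bfloat16"   # assumed pre-quantization default

        def create_parameter(self, shape, dtype):
            # A real Paddle layer would allocate a tensor; here we just record the dtype.
            return {"shape": shape, "dtype": dtype}

    buggy = ToyLayer()
    w = buggy.create_parameter(shape=[128, 256], dtype=buggy.weight_dtype)
    buggy.weight_dtype = "float8_e4m3fn"       # too late: the weight already exists
    assert w["dtype"] == "bfloat16"

    fixed = ToyLayer()
    fixed.weight_dtype = "float8_e4m3fn"       # set the dtype first, as in the commit
    w = fixed.create_parameter(shape=[128, 256], dtype=fixed.weight_dtype)
    assert w["dtype"] == "float8_e4m3fn"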
