
Commit d40a104

[Feature] support rl_tp_degree (#3934)
* [Feature] support rl_tp_degree
* add rl_tp_degree in lmhead
* add rl_tp_degree in bias
* fix split_axis=0 in bias
* fix split_axis in weight
* fix bias rl_tp_degree
* fix bias rl_tp_degree
* change attr to dict

--------

Co-authored-by: Jiang-Jia-Jun <[email protected]>
1 parent fa23692 commit d40a104

File tree

3 files changed: +31, -1 lines changed

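The change is mechanically simple: every tensor-parallel-sharded weight (and bias, where present) gets an extra `rl_need_attr` dict recording the tensor-parallel degree under the key `rl_tp_degree`, attached via FastDeploy's `set_weight_attrs` helper. As a reading aid, here is a minimal sketch of that attach-and-read pattern; the helper body and the `_FakeParam` stand-in are assumptions for illustration, not code from the repo:

    # Minimal sketch of the set_weight_attrs pattern; the helper body here
    # is an assumption for illustration, not copied from FastDeploy.
    def set_weight_attrs(weight, attrs: dict) -> None:
        # Attach each key/value pair as a plain Python attribute on the tensor.
        for key, value in attrs.items():
            setattr(weight, key, value)

    class _FakeParam:
        # Stand-in for a paddle parameter object.
        pass

    w = _FakeParam()
    set_weight_attrs(w, {"rl_need_attr": {"rl_tp_degree": 4}})

    # An RL-side consumer can later read the tag back off the parameter:
    tp = getattr(w, "rl_need_attr", {}).get("rl_tp_degree", 1)
    print(tp)  # -> 4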

fastdeploy/model_executor/layers/embeddings.py

Lines changed: 5 additions & 0 deletions
@@ -77,6 +77,11 @@ def __init__(
             )
             if self.world_size > 1:
                 set_weight_attrs(self.embeddings.weight, {"output_dim": False})
+                set_weight_attrs(
+                    self.embeddings.weight,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+
         else:
             # column cut embedding
             self.embeddings = nn.Embedding(
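For context on the surrounding lines: `output_dim: False` marks the embedding weight as sharded along its first axis, matching how paddle's `VocabParallelEmbedding` splits the vocabulary rows across ranks. A hedged numpy sketch of that row split (the shapes are invented for illustration):

    import numpy as np

    # Sketch only: how a vocab-parallel embedding shards its weight. The real
    # class is paddle's fleet.meta_parallel.VocabParallelEmbedding.
    vocab_size, hidden, tp = 8, 4, 2
    full = np.arange(vocab_size * hidden, dtype=np.float32).reshape(vocab_size, hidden)

    # Split along dim 0 (the vocab axis) -- hence output_dim=False on the weight.
    shards = np.split(full, tp, axis=0)
    for rank, shard in enumerate(shards):
        print(rank, shard.shape)  # each rank holds (vocab_size // tp, hidden)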

fastdeploy/model_executor/layers/linear.py

Lines changed: 17 additions & 1 deletion
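Two things happen in the first hunk below. The column-parallel weight is now explicitly marked as distributed along its last axis (`split_axis=-1`), and the bias marking is corrected from `split_axis=1` to `split_axis=0`: the bias of a column-parallel linear is one-dimensional, so axis 0 is the only axis it can be split on. Both tensors also pick up the `rl_need_attr` tag. A small numpy sketch of the split geometry, assuming paddle's `[in_features, out_features]` weight layout (an assumption, not stated in the diff):

    import numpy as np

    # Split geometry for a column-parallel linear, assuming an
    # [in_features, out_features] weight layout.
    in_f, out_f, tp = 4, 6, 2
    weight = np.zeros((in_f, out_f), dtype=np.float32)
    bias = np.zeros((out_f,), dtype=np.float32)

    w_shards = np.split(weight, tp, axis=-1)  # weight: split output columns
    b_shards = np.split(bias, tp, axis=0)     # bias is 1-D: axis 0 is all there is
    print(w_shards[0].shape, b_shards[0].shape)  # (4, 3) (3,)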
@@ -356,11 +356,21 @@ def __init__(
         )

         if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=-1)
             if self.with_bias:
                 # col parallel
-                _set_var_distributed(self.bias, split_axis=1)
+                _set_var_distributed(self.bias, split_axis=0)
                 set_weight_attrs(self.bias, {"output_dim": True})

+            # set_rl_tp_degree
+            set_weight_attrs(
+                self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+            )
+            if self.with_bias:
+                set_weight_attrs(
+                    self.bias, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+                )
+

 class MergedColumnParallelLinear(ColumnParallelLinear):
     """
@@ -743,6 +753,7 @@ def __init__(
             model_format=fd_config.model_config.model_format,
         )
         if self.nranks > 0:
+            _set_var_distributed(self.weight, split_axis=0)
             if self.with_bias:
                 # col parallel
                 _set_var_distributed(self.bias, split_axis=0)
@@ -755,6 +766,11 @@ def __init__(

         self.reduce_results = reduce_results

+        # set_rl_tp_degree
+        set_weight_attrs(
+            self.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+        )
+
     def forward_cuda(self, x: paddle.Tensor) -> paddle.Tensor:
         if self.fd_config.quant_config:
             out = self.quant_method.apply(self, x)
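Row-parallel is the mirror image of the column-parallel case: the weight is split along axis 0 (the input dimension), each rank computes a partial product, and the partial outputs are summed across ranks, which is what `reduce_results` gates in `forward_cuda`. A numpy sketch of why that sum recovers the full matmul (shapes invented for illustration):

    import numpy as np

    # Why row-parallel partial outputs must be summed (the reduce step).
    in_f, out_f, tp = 4, 3, 2
    x = np.random.rand(1, in_f).astype(np.float32)
    w = np.random.rand(in_f, out_f).astype(np.float32)

    x_shards = np.split(x, tp, axis=1)  # activations split on the shared axis
    w_shards = np.split(w, tp, axis=0)  # weight rows split across ranks
    partials = [xs @ ws for xs, ws in zip(x_shards, w_shards)]
    out = sum(partials)                 # stands in for the all-reduce
    assert np.allclose(out, x @ w, atol=1e-5)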

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 9 additions & 0 deletions
@@ -94,6 +94,12 @@ def __init__(
                     "model_format": self.fd_config.model_config.model_format,
                 },
             )
+            if self.bias_key is not None:
+                set_weight_attrs(
+                    self.linear.bias,
+                    {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}},
+                )
+
             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": True})
         else:
@@ -116,6 +122,9 @@ def __init__(

             if self.nranks > 1:
                 set_weight_attrs(self.linear.weight, {"output_dim": False})
+                set_weight_attrs(
+                    self.linear.weight, {"rl_need_attr": {"rl_tp_degree": fd_config.parallel_config.tensor_parallel_size}}
+                )

     def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         """

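Putting the three files together, every sharded parameter now carries `rl_need_attr["rl_tp_degree"]`. A hypothetical consumer on the RL side could walk the model's parameters and build a resharding plan from these tags; the helper below illustrates that pattern and is not an API from this PR:

    # Hypothetical RL-side consumer; the function name and the use case are
    # illustrations, not part of this commit.
    def collect_rl_tp_degrees(named_params):
        plan = {}
        for name, param in named_params:
            attr = getattr(param, "rl_need_attr", None)
            if attr and "rl_tp_degree" in attr:
                plan[name] = attr["rl_tp_degree"]
        return plan

    # Usage, assuming a paddle model whose layers were built as in this diff:
    #   plan = collect_rl_tp_degrees(model.named_parameters())
    #   # -> {"lm_head.linear.weight": 8, "embeddings.weight": 8, ...}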