
Commit 1e9a8e8

fix lm head bias (#3185)
Co-authored-by: yuanxiaolan <[email protected]>
Parent: f5c64a0

1 file changed: +15 -1 lines changed


fastdeploy/model_executor/layers/lm_head.py

Lines changed: 15 additions & 1 deletion
@@ -72,6 +72,13 @@ def __init__(
                 dtype=paddle.get_default_dtype(),
                 is_bias=False,
             )
+            if self.bias_key is not None:
+                self.bias = self.create_parameter(
+                    shape=[num_embeddings],
+                    dtype=paddle.get_default_dtype(),
+                    is_bias=True,
+                )
+
         else:
             if self.column_cut:
                 need_gather = True
@@ -107,6 +114,10 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
 
         if self.use_ep:
             self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
+            if self.bias_key is not None:
+                self.bias.set_value(
+                    get_tensor(state_dict.pop(self.linear_bias_key)).astype(paddle.get_default_dtype())
+                )
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
@@ -134,7 +145,10 @@ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
         logits = input
         if self.use_ep:
-            logits = paddle.matmul(logits, self.weight)
+            if self.linear_bias_key is None:
+                logits = paddle.matmul(logits, self.weight)
+            else:
+                logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
         else:
             logits = self.linear(logits)
         return logits
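
In short, the fix gives the expert-parallel (use_ep) path of LMHead a bias: __init__ now creates a bias parameter when a bias key is configured, load_state_dict restores it from the checkpoint, and forward applies it through paddle.incubate.nn.functional.fused_linear instead of a bare matmul. Below is a minimal sketch (not part of the commit) of why fused_linear is a drop-in replacement for matmul plus a bias add; the shapes and tensor names (x, w, b) are illustrative, and fused_linear assumes a CUDA build of PaddlePaddle since it is backed by the fused GEMM epilogue kernel.

import paddle
import paddle.incubate.nn.functional as F

# Illustrative shapes standing in for [batch, hidden] x [hidden, vocab].
x = paddle.randn([4, 8], dtype="float32")    # hidden states
w = paddle.randn([8, 16], dtype="float32")   # lm_head weight
b = paddle.randn([16], dtype="float32")      # lm_head bias restored by this fix

# New biased path in forward(); requires a CUDA build of Paddle.
fused = F.fused_linear(x, w, b)
# Old path plus an explicit elementwise bias add, for comparison.
reference = paddle.matmul(x, w) + b

assert paddle.allclose(fused, reference, atol=1e-5).item()

Using the fused kernel keeps the bias add inside the same GEMM launch rather than issuing a separate elementwise add, which matters on the hot logits path.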
