Commit 283da92

fix ep lm head (#3244)
Co-authored-by: yuanxiaolan <[email protected]>
Parent: f516421

1 file changed: +2 −4 lines

fastdeploy/model_executor/layers/lm_head.py

Lines changed: 2 additions & 4 deletions
```diff
@@ -118,9 +118,7 @@ def load_state_dict(self, state_dict: Dict[str, paddle.Tensor | np.ndarray]):
         if self.use_ep:
             self.weight.set_value(get_tensor(state_dict.pop(self.weight_key)).astype(paddle.get_default_dtype()))
             if self.bias_key is not None:
-                self.bias.set_value(
-                    get_tensor(state_dict.pop(self.linear_bias_key)).astype(paddle.get_default_dtype())
-                )
+                self.bias.set_value(get_tensor(state_dict.pop(self.bias_key)).astype(paddle.get_default_dtype()))
         else:
             if self.tie_word_embeddings:
                 self.linear.weight.set_value(
@@ -148,7 +146,7 @@ def forward(self, input: paddle.Tensor) -> paddle.Tensor:
         """
         logits = input
         if self.use_ep:
-            if self.linear_bias_key is None:
+            if self.bias_key is None:
                 logits = paddle.matmul(logits, self.weight)
             else:
                 logits = paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
```
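The bug was a stale attribute name: in the expert-parallel (`use_ep`) branch, both `load_state_dict` and `forward` referenced `self.linear_bias_key`, while the surrounding context lines show the attribute this layer actually defines is `self.bias_key`, so an EP model with a biased LM head would hit an `AttributeError` at load or inference time. Below is a minimal sketch of the corrected control flow. The `LMHeadSketch` class, its default key names, and the use of `paddle.to_tensor` in place of fastdeploy's `get_tensor` helper are all illustrative assumptions, not the real implementation:

```python
from typing import Dict, Optional

import numpy as np
import paddle


class LMHeadSketch:
    """Hypothetical, stripped-down stand-in for the expert-parallel (use_ep)
    branch of fastdeploy's LM head; names and shapes are illustrative only."""

    def __init__(self, hidden_size: int, vocab_size: int,
                 weight_key: str = "lm_head.weight",
                 bias_key: Optional[str] = "lm_head.bias"):
        self.use_ep = True
        self.weight_key = weight_key
        # The one attribute that load_state_dict and forward must agree on.
        self.bias_key = bias_key
        dtype = paddle.get_default_dtype()
        self.weight = paddle.create_parameter([hidden_size, vocab_size], dtype)
        self.bias = paddle.create_parameter([vocab_size], dtype) if bias_key else None

    def load_state_dict(self, state_dict: Dict[str, "paddle.Tensor | np.ndarray"]):
        dtype = paddle.get_default_dtype()
        self.weight.set_value(paddle.to_tensor(state_dict.pop(self.weight_key)).astype(dtype))
        if self.bias_key is not None:
            # Pre-fix, this popped state_dict[self.linear_bias_key]; since no
            # such attribute exists here, loading a biased head raised AttributeError.
            self.bias.set_value(paddle.to_tensor(state_dict.pop(self.bias_key)).astype(dtype))

    def forward(self, input: paddle.Tensor) -> paddle.Tensor:
        logits = input
        # Pre-fix, this tested self.linear_bias_key and failed at inference time.
        if self.bias_key is None:
            return paddle.matmul(logits, self.weight)
        # With a bias present, fuse the matmul and bias-add into one kernel.
        return paddle.incubate.nn.functional.fused_linear(logits, self.weight, self.bias)
```

A quick smoke test that constructs the head both with and without a `bias_key`, calls `load_state_dict` with matching keys, and then runs `forward` would exercise both branches and catch a stale attribute reference like this one before it ships.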
