Skip to content

Commit 08cc488

Browse files
committed
fix deepseekv2 balance
1 parent 5d6fe53 commit 08cc488

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -225,6 +225,11 @@ def _tpsp_get_qkv(
225 225
infer_state.position_cos,
226 226
infer_state.position_sin,
227 227
)
228+
229+
if infer_state.need_dp_prefill_balance:
230+
q = infer_state._all_to_all_unbalance_get(data=q)
231+
cache_kv = infer_state._all_to_all_unbalance_get(data=cache_kv)
232+
228 233
return q, cache_kv
229 234

230 235
def _get_o(
@@ -238,6 +243,9 @@ def _get_o(
238 243
def _tpsp_get_o(
239 244
self, input, infer_state: Deepseek2InferStateInfo, layer_weight: Deepseek2TransformerLayerWeight
240 245
) -> torch.Tensor:
246+
if infer_state.need_dp_prefill_balance:
247+
input = infer_state._all_to_all_balance_get(data=input)
248+
241 249
if input.shape[2] == self.kv_lora_rank:
242 250
input = layer_weight.v_b_proj_.bmm(input.transpose(0, 1)).transpose(0, 1)
243 251

0 commit comments

Comments
 (0)