@@ -14,6 +14,7 @@
 from functools import partial
 from lightllm.utils.log_utils import init_logger
 from lightllm.utils.dist_utils import get_global_world_size
+from lightllm.distributed.communication_op import all_gather_into_tensor
 
 logger = init_logger(__name__)
 
@@ -82,6 +83,48 @@ def _get_qkv( |
         )
         return q, cache_kv
 
+    def _tpsp_get_qkv(
+        self,
+        input: torch.Tensor,
+        cache_kv,
+        infer_state: LlamaInferStateInfo,
+        layer_weight: Qwen3MOETransformerLayerWeight,
+    ) -> torch.Tensor:
+        if self.tp_world_size_ > 1:
+            sp_token_num, hidden_dim = input.shape
+            gather_input = self.alloc_tensor(
+                (sp_token_num * self.tp_world_size_, hidden_dim), dtype=input.dtype, device=input.device
+            )
+            all_gather_into_tensor(gather_input, input, group=infer_state.dist_group, async_op=False)
+            input = gather_input[0 : len(infer_state.position_cos), :]
+
+        input = input.view(-1, self.embed_dim_)
+        q = layer_weight.q_proj.mm(input)
+        cache_kv = layer_weight.kv_proj.mm(
+            input, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_)
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
+        rmsnorm_forward(
+            q.view(-1, self.head_dim_),
+            weight=layer_weight.q_norm_weight_.weight,
+            eps=self.eps_,
+            out=q.view(-1, self.head_dim_),
+        )
+
+        cache_kv[:, : self.tp_k_head_num_, :] = rmsnorm_forward(
+            cache_kv[:, : self.tp_k_head_num_, :].reshape(-1, cache_kv.shape[-1]),
+            weight=layer_weight.k_norm_weight_.weight,
+            eps=self.eps_,
+        ).view(-1, self.tp_k_head_num_, cache_kv.shape[-1])
+
+        rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
+        return q, cache_kv
+
     def _moe_ffn(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: Qwen3MOETransformerLayerWeight
     ) -> torch.Tensor:
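
Context for the new TP-SP path: each rank holds an equal-sized (possibly padded) shard of the token sequence, so _tpsp_get_qkv first all-gathers the shards into one buffer and trims it to the true token count (len(infer_state.position_cos)) before running the q/kv projections, the q/k RMSNorm, and the rotary embedding. The sketch below isolates just that gather-and-trim step, using plain torch.distributed instead of lightllm's all_gather_into_tensor wrapper; the helper name gather_sp_shard and the true_token_num argument are illustrative only, not part of the codebase.

    # Minimal sketch of the gather-and-trim step, not the lightllm API. Assumes
    # torch.distributed is initialized and tokens are split contiguously across
    # ranks in rank order, with any padding confined to the tail shard.
    import torch
    import torch.distributed as dist

    def gather_sp_shard(shard: torch.Tensor, true_token_num: int, group=None) -> torch.Tensor:
        """All-gather sequence-parallel shards of shape (sp_token_num, hidden_dim),
        then drop the padding rows so downstream projections see only real tokens."""
        world_size = dist.get_world_size(group=group)
        sp_token_num, hidden_dim = shard.shape
        # One contiguous buffer large enough for every rank's shard.
        gathered = torch.empty(
            (sp_token_num * world_size, hidden_dim), dtype=shard.dtype, device=shard.device
        )
        dist.all_gather_into_tensor(gathered, shard, group=group)
        # Shards are padded to equal length; trim back to the real token count,
        # which is the role len(infer_state.position_cos) plays in _tpsp_get_qkv.
        return gathered[:true_token_num, :]

After this gather, every rank projects the full token batch onto its own slice of heads, which is why the RMSNorm and rotary embedding in the new method operate on tp_q_head_num_ / tp_k_head_num_ heads rather than the full head count.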