
Commit a799d14

Authored by gzy19990617, bukejiyu, rsmallblue, zeroRains, and xjkmfa
[Bugfix] Fix model accuracy in some ops (#3231)
* fix noaux_tc op
* fix
* update
* fix qk norm
* fix linear for prequant loader
* test
* fix
* fix
* rm some print
* fix noaux_tc op
* test
* Fix the confused enable_early_stop when only early_stop_config is set (#3214)
* fix the confused early_stop_config when only early_stop_config is set
* pre-commit
* write a general method
* Add CI case for min token and max token (#3229)
  Co-authored-by: xujing43 <[email protected]>
* add some evil cases (#3240)
* add repetition early stop cases
* add repetition early stop cases
* add bad cases
* add bad cases
* add evil cases
* qwen3_moe (#3084)
* [Feature] support seed parameter (#3161)
* support seed
* fix
* add SamplingMetadata seed test
* The next_tokens values are inconsistent!
* add air and rejection seed test
* fix
* add SamplingParams seed test
* fix seed=0
* Default to default
* fix
* fix args_utils
* fix review
* fix review
* fix
* fix
* add xpu, gcu, iluvatar seed support
* fix
* [Fix Bug] Fix fa3 centralized-deployment support bug (#3235)
* fix fa3 centralized-mode bug
* add qknorm parameter
* fix qk norm
* fix
* update
* fix linear for prequant loader
* fix
* fix
* rm some print
* fix
* fix moe init weight & scale
* fix moe init weight & scale

---------

Co-authored-by: bukejiyu <[email protected]>
Co-authored-by: yuanxiaolan <[email protected]>
Co-authored-by: Zero Rains <[email protected]>
Co-authored-by: xjkmfa <[email protected]>
Co-authored-by: xujing43 <[email protected]>
Co-authored-by: Divano <[email protected]>
Co-authored-by: bukejiyu <[email protected]>
Co-authored-by: lizexu123 <[email protected]>
Co-authored-by: yangjianfengo1 <[email protected]>
Co-authored-by: qingqing01 <[email protected]>
1 parent ce1f353 commit a799d14

File tree

8 files changed: +62 −31 lines changed

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh

Lines changed: 5 additions & 6 deletions
```diff
@@ -56,15 +56,14 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;

-  int64_t global_warp_idx = blockIdx.x * blockDim.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   int64_t all_head_dim = elem_cnt / head_size;

   const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
-  // const int64_t offset = 2 * hidden_size;
   const int half_head_size = head_size / 2;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_dim; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * head_size + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * head_size + threadIdx.x * VecSize;
     const int ori_bi = linear_index / hidden_size;
     const int bias = linear_index % hidden_size;
     const int hi = bias / head_size;  // q + k + v
@@ -122,13 +121,13 @@ __global__ void append_decode_cache_T_rope_qk_norm_kernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < num_heads) {  // q
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {  // k
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         out_vec[i] = static_cast<T>(static_cast<float>(out_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
```

custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_kernel.cu

Lines changed: 1 addition & 2 deletions
```diff
@@ -45,15 +45,14 @@ void append_decode_cache_rope_qk_norm(const QKV_TYPE* qkv,
   const uint32_t elem_nums =
       use_neox_style ? bsz * (num_heads + 2 * kv_num_heads) * dim_head / 2
                      : bsz * (num_heads + 2 * kv_num_heads) * dim_head;
-  assert(dim_head == 128 && "dim_head must be 128");
   constexpr int HEAD_DIM = 128;

   constexpr int PackSize = HEAD_DIM / kWarpSize;
   const int pack_num = elem_nums / PackSize;
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 block_dim(blocksize / kWarpSize, kWarpSize, 1);
+  dim3 block_dim(kWarpSize, blocksize / kWarpSize, 1);
   append_decode_cache_T_rope_qk_norm_kernel<T, PackSize>
       <<<grid_size, block_dim, 0, stream>>>(reinterpret_cast<const T*>(qkv),
                                             key_cache,
```
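Why this fixes accuracy: the launcher above now lays the block out as dim3(kWarpSize, blocksize / kWarpSize), so threadIdx.x is the lane within a hardware warp and threadIdx.y selects the warp; the kernel swaps its uses of threadIdx.x and threadIdx.y to match. Each warp then owns exactly one head, so the warp-wide statistics behind the QK RMSNorm (row_variance, and the norm-weight indexing) no longer mix elements from different heads. A minimal Python sketch of the index math, assuming 32-thread warps laid out threadIdx.x-fastest; names here are illustrative, not repo code:

```python
# Simulate the thread -> element mapping before and after the fix,
# assuming head_size == 128 (the kernel's dim_head assumption),
# kWarpSize == 32 and blocksize == 128.
KWARP, HEAD, BLOCK = 32, 128, 128
VEC = HEAD // KWARP  # VecSize = 4 elements per lane

def hw_warp_lanes(warp_id, bx):
    """A hardware warp is 32 consecutive threads, threadIdx.x fastest."""
    return [(t % bx, t // bx) for t in range(warp_id * 32, warp_id * 32 + 32)]

def head_of(tx, ty, lane_is_x):
    # new mapping: ty picks the head, tx * VEC the within-head offset;
    # old mapping: tx picked the head, ty * VEC the offset.
    head, off = (ty, tx) if lane_is_x else (tx, ty)
    return (head * HEAD + off * VEC) // HEAD

# Old launch (blocksize/kWarpSize, kWarpSize) = (4, 32): one hardware warp
# touched 4 different heads, so warp-wide norm statistics mixed heads.
assert len({head_of(tx, ty, False) for tx, ty in hw_warp_lanes(0, 4)}) == 4
# New launch (kWarpSize, blocksize/kWarpSize) = (32, 4): each hardware warp
# owns exactly one head, matching the per-head QK RMSNorm reduction.
assert len({head_of(tx, ty, True) for tx, ty in hw_warp_lanes(0, 32)}) == 1
```

The encoder-path GQAVariableLengthRotaryQKNormKernel in the next file gets the same relayout; its launcher additionally fixes a block shape that was mistakenly computed from grid_size/kWarpSize rather than blocksize/kWarpSize.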

custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh

Lines changed: 7 additions & 7 deletions
```diff
@@ -432,13 +432,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
   LoadT src_vec;
   LoadEmbT cos_emb_vec;
   LoadEmbT sin_emb_vec;
-  int64_t global_warp_idx = blockDim.x * blockIdx.x + threadIdx.x;
-  int64_t all_warp_num = gridDim.x * blockDim.x;
+  int64_t global_warp_idx = blockDim.y * blockIdx.x + threadIdx.y;
+  int64_t all_warp_num = gridDim.x * blockDim.y;
   const int half_lastdim = last_dim / 2;
   const int offset = (q_num_head + kv_num_head) * last_dim;
   const int all_head_num = elem_cnt / last_dim;
   for (int gloabl_hi = global_warp_idx; gloabl_hi < all_head_num; gloabl_hi += all_warp_num) {
-    int64_t linear_index = gloabl_hi * last_dim + threadIdx.y * VecSize;
+    int64_t linear_index = gloabl_hi * last_dim + threadIdx.x * VecSize;
     const int token_idx = linear_index / offset;
     const int ori_bi = batch_id_per_token[token_idx];
     if (seq_lens[ori_bi] == 0) continue;
@@ -478,13 +478,13 @@ __global__ void GQAVariableLengthRotaryQKNormKernel(
     float row_inv_var = Rsqrt(row_variance + rms_norm_eps);
     LoadT q_norm_vec, k_norm_vec;
     if (hi < q_num_head) {
-      Load<T, VecSize>(&q_norm_weight[threadIdx.y * VecSize], &q_norm_vec);
+      Load<T, VecSize>(&q_norm_weight[threadIdx.x * VecSize], &q_norm_vec);
 #pragma unroll
       for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(q_norm_vec[i]));
       }
     } else {
-      Load<T, VecSize>(&k_norm_weight[threadIdx.y * VecSize], &k_norm_vec);
+      Load<T, VecSize>(&k_norm_weight[threadIdx.x * VecSize], &k_norm_vec);
       for (int i = 0; i < VecSize; i++) {
         src_vec[i] = static_cast<T>(static_cast<float>(src_vec[i]) * row_inv_var * static_cast<float>(k_norm_vec[i]));
       }
@@ -1690,13 +1690,13 @@ void gqa_rotary_qk_norm_variable(
   const int blocksize = 128;
   int grid_size = 1;
   GetNumBlocks<128>(pack_num, &grid_size);
-  dim3 Blocks(grid_size/kWarpSize, kWarpSize, 1);
+  dim3 Block_Size(kWarpSize, blocksize/kWarpSize, 1);

   const float *cos_emb = rotary_emb;
   const float *sin_emb = rotary_emb + input_output_len * dim_head / 2;

   GQAVariableLengthRotaryQKNormKernel<T, PackSize>
-      <<<grid_size, Blocks, 0, stream>>>(
+      <<<grid_size, Block_Size, 0, stream>>>(
           reinterpret_cast<const T *>(qkv_input),
           cos_emb,
           sin_emb,
```

custom_ops/gpu_ops/append_attn/utils.cuh

Lines changed: 3 additions & 0 deletions
```diff
@@ -430,6 +430,9 @@ __forceinline__ __host__ __device__ void vec_cast<nv_bfloat16, float>(
   } else if (group_size == 12) {      \
     constexpr size_t GROUP_SIZE = 12; \
     __VA_ARGS__                       \
+  } else if (group_size == 14) {      \
+    constexpr size_t GROUP_SIZE = 14; \
+    __VA_ARGS__                       \
   } else if (group_size == 16) {      \
     constexpr size_t GROUP_SIZE = 16; \
     __VA_ARGS__                       \
```
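This macro dispatches to kernels specialized on the GQA group size, i.e. the ratio of query heads to KV heads; the new branch admits models whose ratio is 14. A hypothetical illustration of the dispatch logic (the full set of supported sizes is assumed; only the 12/14/16 branches are visible in this hunk):

```python
# Hypothetical sketch, not repo code: a checkpoint with, say, 28 query
# heads and 2 KV heads needs the new GROUP_SIZE == 14 specialization.
SUPPORTED_GROUP_SIZES = {1, 2, 4, 8, 12, 14, 16}  # assumed; 14 is the new case

def dispatch_group_size(num_q_heads: int, num_kv_heads: int) -> int:
    assert num_q_heads % num_kv_heads == 0
    group_size = num_q_heads // num_kv_heads
    if group_size not in SUPPORTED_GROUP_SIZES:
        raise ValueError(f"unsupported GQA group size {group_size}")
    return group_size

assert dispatch_group_size(28, 2) == 14  # previously hit no branch
```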

custom_ops/gpu_ops/noaux_tc.cu

Lines changed: 6 additions & 5 deletions
```diff
@@ -28,19 +28,20 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                     int topk,
                                     float routed_scaling_factor) {
   auto input_shape = scores_with_bias.shape();
+  PD_CHECK(input_shape.size() == 2);
   int64_t num_tokens = input_shape[0];
   int64_t num_experts = input_shape[1];
   auto input_type = scores_with_bias.dtype();
   auto place = scores_with_bias.place();
   auto group_scores = paddle::empty({num_tokens, n_group}, input_type, place);
   auto topk_values = paddle::empty({num_tokens, topk}, input_type, place);
-  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT32, place);
+  auto topk_indices = paddle::empty({num_tokens, topk}, paddle::DataType::INT64, place);
   auto stream = scores_with_bias.stream();

-  invokeNoAuxTc<float, int32_t>(reinterpret_cast<float*>(scores.data<float>()),
+  invokeNoAuxTc<float, int64_t>(reinterpret_cast<float*>(scores.data<float>()),
                                 reinterpret_cast<float*>(group_scores.data<float>()),
                                 reinterpret_cast<float*>(topk_values.data<float>()),
-                                reinterpret_cast<int32_t*>(topk_indices.data<int32_t>()),
+                                reinterpret_cast<int64_t*>(topk_indices.data<int64_t>()),
                                 reinterpret_cast<float*>(scores_with_bias.data<float>()),
                                 num_tokens,
                                 num_experts,
@@ -56,7 +57,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
 std::vector<paddle::DataType> NoauxTcInferDtype(
     const paddle::DataType& scores_dtype,
     const paddle::DataType& scores_with_bias_dtype) {
-  return {scores_dtype, scores_dtype, paddle::DataType::INT32};
+  return {scores_dtype, scores_dtype, paddle::DataType::INT64};
 }

 std::vector<std::vector<int64_t>> NoauxTcInferShape(
@@ -71,7 +72,7 @@ std::vector<std::vector<int64_t>> NoauxTcInferShape(
 PD_BUILD_STATIC_OP(noaux_tc)
     .Inputs({"scores", "scores_with_bias"})
-    .Outputs({"output_tensor"})
+    .Outputs({"output_tensor", "topk_values", "topk_indices"})
     .Attrs({"n_group: int",
             "topk_group: int",
             "topk:int",
```

fastdeploy/model_executor/layers/moe/ep.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -49,7 +49,7 @@ def get_moe_scores(
     compute moe scores using e_score_correction_bias.
     """
     scores = paddle.nn.functional.sigmoid(gating_output)
-    scores_with_bias = scores + e_score_correction_bias.unsqueeze(0)
+    scores_with_bias = scores + e_score_correction_bias
     scores, topk_values, topk_idx = noaux_tc(
         scores,
         scores_with_bias,
```
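The accuracy bug here was a stray unsqueeze: if e_score_correction_bias is already stored with a leading batch dimension, the extra unsqueeze(0) makes it 3-D, and NumPy-style broadcasting silently promotes the sum to 3-D as well, corrupting the shape NoauxTc expects. A minimal sketch of the shape arithmetic, assuming a [1, num_experts] bias (shapes are illustrative):

```python
import paddle

num_tokens, num_experts = 4, 64
scores = paddle.rand([num_tokens, num_experts])
bias = paddle.zeros([1, num_experts])  # already 2-D

# old: unsqueeze(0) -> [1, 1, 64]; broadcasting yields a 3-D result
assert (scores + bias.unsqueeze(0)).shape == [1, num_tokens, num_experts]
# new: a plain broadcast keeps the expected [num_tokens, num_experts]
assert (scores + bias).shape == [num_tokens, num_experts]
```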

fastdeploy/model_executor/layers/moe/fused_moe_deepgemm_backend.py

Lines changed: 20 additions & 7 deletions
```diff
@@ -312,13 +312,26 @@ def apply_tp(
         below is TP compute method.
         """
         gate_out = gate(x.cast("float32"))
-        topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
-            gate_out,
-            layer.gate_correction_bias,
-            layer.top_k,
-            True,  # apply_norm_weight
-            False,
-        )
+
+        if layer.topk_method == "noaux_tc":
+            from .ep import get_moe_scores
+
+            _, topk_weights, topk_ids = get_moe_scores(
+                gate_out,
+                layer.n_group,
+                layer.topk_group,
+                layer.top_k,
+                layer.routed_scaling_factor,
+                layer.gate_correction_bias,
+            )
+        else:
+            topk_ids, topk_weights = fastdeploy.model_executor.ops.gpu.moe_topk_select(
+                gate_out,
+                layer.gate_correction_bias,
+                layer.top_k,
+                True,  # apply_norm_weight
+                False,
+            )

         tmp = count_tokens_per_expert_func(topk_ids, layer.num_experts)
```

fastdeploy/model_executor/layers/moe/moe.py

Lines changed: 19 additions & 3 deletions
```diff
@@ -285,7 +285,7 @@ def init_moe_weights(self):
             dtype="float32",
         )
         up_gate_proj_output_dim = self.moe_intermediate_size * 2
-        if self.moe_quant_type in ["fp8", "wint8"]:
+        if self.moe_quant_type in ["block_wise_fp8", "wint8"]:
             up_gate_proj_weight_shape = [
                 self.num_local_experts,
                 up_gate_proj_output_dim,
@@ -309,9 +309,10 @@ def init_moe_weights(self):
         ]

         # Create parameters
-        if self.moe_quant_type == "fp8":
+        if self.moe_quant_type == "block_wise_fp8":
             # (TODO:gaoziyuan)
-            pass
+            self.weight_dtype = "float8_e4m3fn"
+            self.init_block_wise_fp8_scale()
         elif self.moe_quant_type == "wint8":
             self.weight_dtype = "int8"
             self.init_weight_only_scale()
@@ -342,6 +343,21 @@ def init_weight_only_scale(self):
             dtype=self._dtype,
         )

+    def init_block_wise_fp8_scale(self):
+        """
+        Initialize the weight scale.
+        """
+        self.up_gate_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.moe_intermediate_size * 2 // 128, self.hidden_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+        self.down_proj_weight_scale = self.create_parameter(
+            shape=[self.num_local_experts, self.hidden_size // 128, self.moe_intermediate_size // 128],
+            dtype="float32",
+            is_bias=False,
+        )
+
     def load_experts_weight(
         self,
         state_dict: dict,
```
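The new scale parameters follow a 128×128 block-wise FP8 convention, presumably matching the DeepGEMM backend's tile granularity: one float32 scale per 128×128 weight tile, hence shapes of [num_local_experts, out_dim // 128, in_dim // 128]. A sketch of the quantization that produces scales of exactly these shapes; this is a hypothetical helper for illustration, not the repo's loader, and it assumes dimensions divisible by 128:

```python
import paddle

BLK = 128
FP8_MAX = 448.0  # finite max of float8_e4m3

def block_wise_fp8_quant(w):
    """w: [out_dim, in_dim], both assumed divisible by 128."""
    out_dim, in_dim = w.shape
    tiles = w.reshape([out_dim // BLK, BLK, in_dim // BLK, BLK])
    amax = paddle.clip(tiles.abs().max(axis=[1, 3]), min=1e-12)
    scale = (amax / FP8_MAX).astype("float32")  # one scale per tile
    q = (tiles / scale[:, None, :, None]).reshape(w.shape)
    # float8_e4m3fn requires a recent Paddle build
    return q.astype("float8_e4m3fn"), scale     # scale: [out//128, in//128]

# Per expert, the up_gate_proj weight is [moe_intermediate_size * 2, hidden_size],
# so stacking num_local_experts scale grids gives exactly the parameter shape
# [num_local_experts, moe_intermediate_size * 2 // 128, hidden_size // 128].
```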
