Commit 7c91907

[Metax] support cutlass moe & optimize flash attention (#4208)

1 parent 2b2b645 · commit 7c91907

20 files changed: +2786 −103 lines

custom_ops/gpu_ops/helper.h

Lines changed: 2 additions & 0 deletions

@@ -14,6 +14,8 @@
 
 #pragma once
 
+#include <cuda_fp8.h>
+
 #ifndef PADDLE_WITH_COREX
 #include "glog/logging.h"
 #endif
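
The new include pulls in CUDA's FP8 types, presumably to back FP8 paths elsewhere in the custom ops. For reference, a minimal sketch of what the header provides (hypothetical usage, not code from this commit):

// Hypothetical usage of <cuda_fp8.h>, not part of this commit.
// __nv_fp8_e4m3 is an 8-bit float with 4 exponent / 3 mantissa bits.
__device__ float fp8_roundtrip(float x) {
  __nv_fp8_e4m3 q(x);  // quantize to FP8 (round-to-nearest)
  return float(q);     // dequantize back to float
}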

custom_ops/metax_ops/fused_moe.cu

Lines changed: 181 additions & 0 deletions
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "helper.h"
#include "mc_fused_moe_helper.h"
#include "fused_moe_op.h"

// One thread per expert: record how many rows of the expert-sorted token
// index array fall at or before this expert, i.e. the per-expert row
// offsets consumed by the grouped expert GEMM.
__global__ void compute_total_rows_before_expert_kernel(
    int* sorted_experts,
    const int64_t sorted_experts_len,
    const int64_t num_experts,
    int32_t* total_rows_before_expert) {
  const int expert = blockIdx.x * blockDim.x + threadIdx.x;
  if (expert >= num_experts) return;

  total_rows_before_expert[expert] =
      find_total_elts_leq_target(sorted_experts, sorted_experts_len, expert);
}

void compute_total_rows_before_expert(int* sorted_indices,
                                      const int64_t total_indices,
                                      const int64_t num_experts,
                                      int32_t* total_rows_before_expert,
                                      cudaStream_t stream) {
  const int threads = std::min(int64_t(1024), num_experts);
  const int blocks = (num_experts + threads - 1) / threads;

  compute_total_rows_before_expert_kernel<<<blocks, threads, 0, stream>>>(
      sorted_indices, total_indices, num_experts, total_rows_before_expert);
}
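
The helper find_total_elts_leq_target comes from fused_moe_op.h and is not part of this diff. A plausible reconstruction of its contract (a binary search counting the entries of the ascending-sorted array that are <= target), consistent with how the kernel above uses it:

// Hypothetical reconstruction; the real definition lives in fused_moe_op.h.
// Returns how many entries of sorted_indices are <= target, assuming the
// array is sorted ascending.
__device__ inline int64_t find_total_elts_leq_target(const int* sorted_indices,
                                                     const int64_t arr_length,
                                                     const int64_t target) {
  int64_t low = 0, high = arr_length - 1, target_location = -1;
  while (low <= high) {
    const int64_t mid = (low + high) / 2;
    if (sorted_indices[mid] > target) {
      high = mid - 1;
    } else {
      low = mid + 1;
      target_location = mid;  // rightmost entry <= target so far
    }
  }
  return target_location + 1;  // count of elements <= target
}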
template <paddle::DataType T, typename ElementA, typename ElementB, typename ElementC>
void FusedMoeKernel(const paddle::Tensor& input,
                    const paddle::Tensor& gate_weight,
                    const paddle::Tensor& ffn1_weight,
                    const paddle::optional<paddle::Tensor>& ffn1_scale,
                    const paddle::optional<paddle::Tensor>& ffn1_bias,
                    const paddle::Tensor& ffn2_weight,
                    const paddle::optional<paddle::Tensor>& ffn2_scale,
                    const paddle::optional<paddle::Tensor>& ffn2_bias,
                    const std::string& quant_method,
                    const int moe_topk,
                    const bool group_moe,
                    const bool norm_topk_prob,
                    paddle::Tensor* output) {
  typedef PDTraits<T> traits_;
  typedef typename traits_::DataType DataType_;
  typedef typename traits_::data_t data_t;

  auto* output_data = output->data<data_t>();

  // McMoeHelper (mc_fused_moe_helper.h) drives gating, token permutation
  // and the two grouped GEMMs on the Metax backend.
  auto moe_compute = McMoeHelper<data_t, ElementA, ElementB, ElementC>(quant_method);

  moe_compute.computeFFN(
      &input,
      &gate_weight,
      &ffn1_weight,
      ffn1_scale ? ffn1_scale.get_ptr() : nullptr,
      ffn1_bias ? ffn1_bias.get_ptr() : nullptr,
      &ffn2_weight,
      ffn2_scale ? ffn2_scale.get_ptr() : nullptr,
      ffn2_bias ? ffn2_bias.get_ptr() : nullptr,
      nullptr,
      moe_topk,
      group_moe,
      norm_topk_prob,
      1.0,  // ComputeFFN
      "ffn",
      output);
}

std::vector<paddle::Tensor> FusedExpertMoe(
    const paddle::Tensor& input,
    const paddle::Tensor& gate_weight,
    const paddle::Tensor& ffn1_weight,
    const paddle::Tensor& ffn2_weight,
    const paddle::optional<paddle::Tensor>& ffn1_bias,
    const paddle::optional<paddle::Tensor>& ffn1_scale,
    const paddle::optional<paddle::Tensor>& ffn2_bias,
    const paddle::optional<paddle::Tensor>& ffn2_scale,
    const std::string& quant_method,
    const int moe_topk,
    const bool norm_topk_prob,
    const bool group_moe) {
  const auto input_type = input.dtype();
  auto output = paddle::empty_like(input);

  switch (input_type) {
    case paddle::DataType::BFLOAT16:
      FusedMoeKernel<paddle::DataType::BFLOAT16, maca_bfloat16, int8_t, maca_bfloat16>(
          input,
          gate_weight,
          ffn1_weight,
          ffn1_scale,
          ffn1_bias,
          ffn2_weight,
          ffn2_scale,
          ffn2_bias,
          quant_method,
          moe_topk,
          group_moe,
          norm_topk_prob,
          &output);
      break;
    // case paddle::DataType::FLOAT16:
    //   FusedMoeKernel<paddle::DataType::FLOAT16>(input,
    //                                             gate_weight,
    //                                             ffn1_weight,
    //                                             ffn1_scale,
    //                                             ffn1_bias,
    //                                             ffn2_weight,
    //                                             ffn2_scale,
    //                                             ffn2_bias,
    //                                             quant_method,
    //                                             moe_topk,
    //                                             group_moe,
    //                                             norm_topk_prob,
    //                                             &output);
    //   break;
    default:
      PD_THROW("Only support bf16 for FusedMoeKernel");
  }
  return {output};
}

std::vector<std::vector<int64_t>> FusedExpertMoeInferShape(
    const std::vector<int64_t>& input_shape,
    const std::vector<int64_t>& gate_weight_shape,
    const std::vector<int64_t>& ffn1_weight_shape,
    const std::vector<int64_t>& ffn2_weight_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn1_scale_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_bias_shape,
    const paddle::optional<std::vector<int64_t>>& ffn2_scale_shape) {
  // The fused MoE FFN is shape-preserving: output matches the input tokens.
  return {input_shape};
}

std::vector<paddle::DataType> FusedExpertMoeInferDtype(
    const paddle::DataType& input_dtype,
    const paddle::DataType& gate_weight_dtype,
    const paddle::DataType& ffn1_weight_dtype,
    const paddle::DataType& ffn2_weight_dtype,
    const paddle::optional<paddle::DataType>& ffn1_bias_dtype,
    const paddle::optional<paddle::DataType>& ffn1_scale_dtype,
    const paddle::optional<paddle::DataType>& ffn2_bias_dtype,
    const paddle::optional<paddle::DataType>& ffn2_scale_dtype) {
  return {input_dtype};
}

PD_BUILD_OP(fused_expert_moe)
    .Inputs({"input",
             "gate_weight",
             "ffn1_weight",
             "ffn2_weight",
             paddle::Optional("ffn1_bias"),
             paddle::Optional("ffn1_scale"),
             paddle::Optional("ffn2_bias"),
             paddle::Optional("ffn2_scale")})
    .Outputs({"output"})
    .Attrs({"quant_method:std::string",
            "moe_topk:int",
            "norm_topk_prob:bool",
            "group_moe:bool"})
    .SetKernelFn(PD_KERNEL(FusedExpertMoe))
    .SetInferShapeFn(PD_INFER_SHAPE(FusedExpertMoeInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(FusedExpertMoeInferDtype));
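
Once registered, the op can be exercised directly through its C++ entry point. A minimal smoke-test sketch; the shapes, values, weight layouts, and the quant_method string are all assumptions for illustration, not taken from this commit:

// Hypothetical smoke test; shapes and values are illustrative only.
// Assumed layouts: ffn1_weight [E, H, 2*I] (fused gate+up projections),
// ffn2_weight [E, I, H].
const int64_t T = 16, H = 4096, E = 8, I = 4096;
auto place = paddle::GPUPlace();
auto input       = paddle::full({T, H}, 1.0, paddle::DataType::BFLOAT16, place);
auto gate_weight = paddle::full({H, E}, 0.0, paddle::DataType::BFLOAT16, place);
auto ffn1_weight = paddle::full({E, H, 2 * I}, 0.0, paddle::DataType::BFLOAT16, place);
auto ffn2_weight = paddle::full({E, I, H}, 0.0, paddle::DataType::BFLOAT16, place);
paddle::optional<paddle::Tensor> none;  // all biases/scales omitted

auto out = FusedExpertMoe(input, gate_weight, ffn1_weight, ffn2_weight,
                          none, none, none, none,
                          /*quant_method=*/"weight_only_int8",  // illustrative; accepted values live in McMoeHelper
                          /*moe_topk=*/2,
                          /*norm_topk_prob=*/true,
                          /*group_moe=*/false)[0];
// Output shape matches the input: [T, H].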
Lines changed: 53 additions & 0 deletions
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include "cutlass_kernels/moe_gemm/fused_moe_gemm_kernels.h"
#include "fused_moe_op.h"

using namespace phi;

// For a two-expert (token-type) gate: push the logit of the expert that
// does not match the token's 0/1 type id towards -inf, so routing follows
// the token type.
template <typename T, int VecSize>
__global__ void moe_token_type_ids_kernel(T *gating_output,
                                          const int *moe_token_type_ids_out,
                                          const int num_rows,
                                          const int num_experts,
                                          const int k) {
  const int moe_token_index = blockIdx.x * blockDim.x + threadIdx.x;

  if (moe_token_index >= num_rows) {
    return;
  }

  // Type id 1 masks expert 0; type id 0 masks expert 1.
  gating_output[moe_token_index * 2] =
      gating_output[moe_token_index * 2] +
      (moe_token_type_ids_out[moe_token_index]) * -1e10;
  gating_output[moe_token_index * 2 + 1] =
      gating_output[moe_token_index * 2 + 1] +
      (1 - moe_token_type_ids_out[moe_token_index]) * -1e10;
}

template <typename T>
void moe_token_type_ids_kernelLauncher(T *gating_output,
                                       const int *moe_token_type_ids_out,
                                       const int num_rows,
                                       const int num_experts,
                                       const int k,
                                       cudaStream_t stream) {
  // Note: the grid is sized for num_rows * k threads, though the kernel
  // only uses num_rows of them; the surplus exits via the bounds check.
  const int blocks = num_rows * k / 512 + 1;
  const int threads = 512;
  moe_token_type_ids_kernel<T, 1><<<blocks, threads, 0, stream>>>(
      gating_output, moe_token_type_ids_out, num_rows, num_experts, k);
}
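
For context, a minimal launch sketch; the buffer names and sizes are assumptions, not part of this commit:

// Hypothetical usage: mask a [num_rows, 2] gating tensor in place so each
// token can only route to the expert matching its 0/1 token-type id.
// d_gating and d_type_ids are assumed pre-filled device buffers.
const int num_rows = 4096;
const int num_experts = 2;  // this kernel hard-codes a two-expert gate
const int k = 1;
moe_token_type_ids_kernelLauncher<float>(
    d_gating, d_type_ids, num_rows, num_experts, k, stream);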
Lines changed: 123 additions & 0 deletions
/*
 * SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION &
 * AFFILIATES. All rights reserved. SPDX-License-Identifier: Apache-2.0
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once
#include <string>
#include <sstream>
#include "cub/cub.cuh"

static const float HALF_FLT_MAX = 65504.F;
static const float HALF_FLT_MIN = -65504.F;
static inline size_t AlignTo16(const size_t& input) {
  static constexpr int ALIGNMENT = 16;
  return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
}

// Thin wrapper around cub::DeviceRadixSort for (expert id, row index)
// pairs; num_bits_ limits the radix passes to the bits an expert id can
// actually occupy.
class CubKeyValueSorter {
 public:
  CubKeyValueSorter() : num_experts_(0), num_bits_(sizeof(int) * 8) {}

  explicit CubKeyValueSorter(const int num_experts)
      : num_experts_(num_experts),
        num_bits_(static_cast<int>(log2(num_experts)) + 1) {}

  void update_num_experts(const int num_experts) {
    num_experts_ = num_experts;
    num_bits_ = static_cast<int>(log2(num_experts)) + 1;
  }

  size_t getWorkspaceSize(const size_t num_key_value_pairs,
                          bool descending = false) {
    num_key_value_pairs_ = num_key_value_pairs;
    size_t required_storage = 0;
    int* null_int = nullptr;
    // CUB convention: a first call with a null workspace pointer only
    // reports the temporary-storage size in required_storage.
    if (descending) {
      cub::DeviceRadixSort::SortPairsDescending(NULL,
                                                required_storage,
                                                null_int,
                                                null_int,
                                                null_int,
                                                null_int,
                                                num_key_value_pairs,
                                                0,
                                                32);
    } else {
      cub::DeviceRadixSort::SortPairs(NULL,
                                      required_storage,
                                      null_int,
                                      null_int,
                                      null_int,
                                      null_int,
                                      num_key_value_pairs,
                                      0,
                                      num_bits_);
    }
    return required_storage;
  }

  template <typename KeyT>
  void run(void* workspace,
           const size_t workspace_size,
           const KeyT* keys_in,
           KeyT* keys_out,
           const int* values_in,
           int* values_out,
           const size_t num_key_value_pairs,
           bool descending,
           cudaStream_t stream) {
    size_t expected_ws_size = getWorkspaceSize(num_key_value_pairs);
    size_t actual_ws_size = workspace_size;

    if (expected_ws_size > workspace_size) {
      std::stringstream err_ss;
      err_ss << "[Error][CubKeyValueSorter::run]\n";
      err_ss << "Error. The allocated workspace is too small to run this "
                "problem.\n";
      err_ss << "Expected workspace size of at least " << expected_ws_size
             << " but got workspace size " << workspace_size << "\n";
      throw std::runtime_error(err_ss.str());
    }
    if (descending) {
      cub::DeviceRadixSort::SortPairsDescending(workspace,
                                                actual_ws_size,
                                                keys_in,
                                                keys_out,
                                                values_in,
                                                values_out,
                                                num_key_value_pairs,
                                                0,
                                                32,  // descending sorts all 32 key bits
                                                stream);
    } else {
      cub::DeviceRadixSort::SortPairs(workspace,
                                      actual_ws_size,
                                      keys_in,
                                      keys_out,
                                      values_in,
                                      values_out,
                                      num_key_value_pairs,
                                      0,
                                      num_bits_,
                                      stream);
    }
  }

 private:
  size_t num_key_value_pairs_;
  int num_experts_;
  int num_bits_;
};
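
The class follows CUB's two-call convention: query the temp-storage size first, then run the sort in the caller-provided workspace. A minimal usage sketch, assuming the d_* device buffers and num_pairs are set up by the caller (error handling elided):

// Two-phase CUB pattern: size query, then sort in the provided workspace.
CubKeyValueSorter sorter(/*num_experts=*/64);  // num_bits_ = log2(64) + 1 = 7
size_t ws_bytes = sorter.getWorkspaceSize(num_pairs);

void* d_ws = nullptr;
cudaMalloc(&d_ws, ws_bytes);

// Sort expert ids (keys) ascending, carrying row indices (values) along.
sorter.run<int>(d_ws, ws_bytes,
                d_expert_ids_in, d_expert_ids_out,
                d_row_ids_in, d_row_ids_out,
                num_pairs, /*descending=*/false, stream);
cudaFree(d_ws);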
