|
| 1 | +// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +// |
| 3 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +// you may not use this file except in compliance with the License. |
| 5 | +// You may obtain a copy of the License at |
| 6 | +// |
| 7 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +// |
| 9 | +// Unless required by applicable law or agreed to in writing, software |
| 10 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +// See the License for the specific language governing permissions and |
| 13 | +// limitations under the License. |
| 14 | + |
| 15 | +#include "paddle/phi/backends/gpu/gpu_context.h" |
| 16 | +#include "paddle/phi/core/dense_tensor.h" |
| 17 | +#include "paddle/phi/core/kernel_registry.h" |
| 18 | +#include "paddle/phi/kernels/full_kernel.h" |
| 19 | +#include "paddle/phi/kernels/funcs/aligned_vector.h" |
| 20 | + |
| 21 | +namespace phi { |
| 22 | + |
| 23 | +template <typename T, typename MTP, int VecSize> |
| 24 | +__global__ void combine_no_weight_bwd_kernel(const int* scatter_index, |
| 25 | + const T* grad_y, |
| 26 | + T* grad_x, |
| 27 | + const int64_t k, |
| 28 | + const int64_t seqlen, |
| 29 | + const int64_t hidden_size) { |
| 30 | + using LoadT = phi::AlignedVector<T, VecSize>; |
| 31 | + LoadT grad_y_vec; |
| 32 | + int i = blockIdx.x; // Batch index (sequence length) |
| 33 | + int ki = blockIdx.y; // Sequence index |
| 34 | + |
| 35 | + if (i < seqlen && ki < k) { |
| 36 | + int idx = scatter_index[i * k + ki]; // Index into x |
| 37 | + |
| 38 | + // Loop over h dimension in strides of block |
| 39 | + for (int h_i = threadIdx.x * VecSize; h_i < hidden_size; |
| 40 | + h_i += blockDim.x * VecSize) { |
| 41 | + phi::Load<T, VecSize>(&(grad_y[i * hidden_size + h_i]), &grad_y_vec); |
| 42 | + phi::Store<T, VecSize>(grad_y_vec, &grad_x[idx * hidden_size + h_i]); |
| 43 | + } |
| 44 | + } |
| 45 | +} |
| 46 | + |
| 47 | +template <typename T> |
| 48 | +void moe_combine_no_weight_bwd(const int* scatter_index, |
| 49 | + const T* grad_y, |
| 50 | + T* grad_x, |
| 51 | + const int64_t k, |
| 52 | + const int64_t seqlen, |
| 53 | + const int64_t hidden_size, |
| 54 | + cudaStream_t stream) { |
| 55 | + int block_size = 512; |
| 56 | + int grid_size_i = seqlen; |
| 57 | + int grid_size_k = k; |
| 58 | + dim3 blockDim(block_size); |
| 59 | + dim3 gridDim(grid_size_i, grid_size_k); |
| 60 | + |
| 61 | + constexpr int max_pack_size = 16 / sizeof(T); |
| 62 | + if (hidden_size % max_pack_size == 0) { |
| 63 | + combine_no_weight_bwd_kernel<T, float, max_pack_size> |
| 64 | + <<<gridDim, blockDim, 0, stream>>>( |
| 65 | + scatter_index, grad_y, grad_x, k, seqlen, hidden_size); |
| 66 | + } else { |
| 67 | + combine_no_weight_bwd_kernel<T, float, 1><<<gridDim, blockDim, 0, stream>>>( |
| 68 | + scatter_index, grad_y, grad_x, k, seqlen, hidden_size); |
| 69 | + } |
| 70 | +} |
| 71 | + |
| 72 | +template <typename T, typename Context> |
| 73 | +void MoeCombineNoWeightGradKernel(const Context& dev_ctx, |
| 74 | + const DenseTensor& x, |
| 75 | + const DenseTensor& scatter_index, |
| 76 | + const DenseTensor& grad_y, |
| 77 | + DenseTensor* grad_x) { |
| 78 | + const auto x_shape = x.dims(); |
| 79 | + const int64_t hidden_size = x_shape[1]; |
| 80 | + |
| 81 | + const auto scatter_index_shape = scatter_index.dims(); |
| 82 | + const int64_t seqlen = scatter_index_shape[0]; |
| 83 | + const int64_t k = scatter_index_shape[1]; |
| 84 | + |
| 85 | + dev_ctx.template Alloc<T>(grad_x); |
| 86 | + phi::Full<T, Context>( |
| 87 | + dev_ctx, phi::IntArray(common::vectorize(grad_x->dims())), 0, grad_x); |
| 88 | + |
| 89 | + moe_combine_no_weight_bwd<T>(scatter_index.data<int>(), |
| 90 | + grad_y.data<T>(), |
| 91 | + grad_x->data<T>(), |
| 92 | + k, |
| 93 | + seqlen, |
| 94 | + hidden_size, |
| 95 | + dev_ctx.stream()); |
| 96 | +} |
| 97 | + |
| 98 | +} // namespace phi |
| 99 | + |
| 100 | +PD_REGISTER_KERNEL(moe_combine_no_weight_grad, |
| 101 | + GPU, |
| 102 | + ALL_LAYOUT, |
| 103 | + phi::MoeCombineNoWeightGradKernel, |
| 104 | + float, |
| 105 | + double, |
| 106 | + phi::dtype::bfloat16, |
| 107 | + phi::dtype::float16) {} |
0 commit comments