diff --git a/.gitmodules b/.gitmodules index eab6041af..d8f4107a5 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +1,6 @@ [submodule "third_party/spdlog"] path = third_party/spdlog url = https://github.com/gabime/spdlog.git +[submodule "moe_gu_ops/src/nvidia_kernels/cutlass"] + path = moe_gu_ops/src/nvidia_kernels/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/moe_gu_ops/pybind_gumoe.cc b/moe_gu_ops/pybind_gumoe.cc new file mode 100644 index 000000000..12929455e --- /dev/null +++ b/moe_gu_ops/pybind_gumoe.cc @@ -0,0 +1,108 @@ +#include +#include +#include // 必须包含,用于自动转换 map 和 string + +#include "gu_moe.h" // MoE 头文件 +#include "infinicore/tensor.hpp" +#include "infinicore/device.hpp" + +namespace py = pybind11; + +// 1. 转换器:Torch Tensor -> Infini Tensor (Zero-Copy) +// 这个函数只创建一个"视图",不拷贝数据。 +// 安全性:因为 Module::load_state_dict 内部会执行 copy_from,所以数据会被安全地拷贝到模型参数里。 +infinicore::Tensor torch_to_infini_view(const torch::Tensor& t) { + // 1. 获取形状 + infinicore::Shape shape; + for (auto s : t.sizes()) shape.push_back(s); + + // 2. 获取数据类型 (目前代码只支持 F32) + infinicore::DataType dtype = infinicore::DataType::F32; + if (t.dtype() == torch::kFloat32) dtype = infinicore::DataType::F32; + else if (t.dtype() == torch::kFloat16) dtype = infinicore::DataType::F16; + else throw std::runtime_error("Unsupported dtype"); + + // 3. 获取设备 + infinicore::Device::Type dev_type = infinicore::Device::Type::CPU; + int dev_id = 0; + if (t.is_cuda()) { + dev_type = infinicore::Device::Type::NVIDIA; + dev_id = t.device().index(); + } + + // 4. 创建 Tensor 视图 (from_blob) + return infinicore::Tensor::from_blob( + t.data_ptr(), + shape, + dtype, + infinicore::Device(dev_type, dev_id) + ); +} + +// ===================================================================== +// 2. 转换器:Infini Tensor -> Torch Tensor (用于 Forward 输出) +// ===================================================================== +torch::Tensor infini_to_torch_copy(infinicore::Tensor t) { + std::vector sizes; + for (auto s : t->shape()) sizes.push_back(s); + + auto options = torch::TensorOptions().dtype(torch::kFloat32); // 假设输出 F32 + if (t->device().getType() == infinicore::Device::Type::NVIDIA) { + options = options.device(torch::kCUDA, t->device().getIndex()); + } else { + options = options.device(torch::kCPU); + } + + // 创建并 clone,确保拥有内存 + return torch::from_blob(t->data(), sizes, options).clone(); +} + +// ===================================================================== +// 3. 
包装类 (Wrapper) +// ===================================================================== +class PyGuMoeWrapper { +public: + std::shared_ptr moe_block; + + // 构造函数:接收 Python 传来的参数 + PyGuMoeWrapper(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk) { + // 假设这里强制使用 NVIDIA:0,你可以根据需要添加 device 参数 + infinicore::Device device(infinicore::Device::Type::NVIDIA, 0); + + moe_block = std::make_shared( + num_experts, hidden_dim, intermediate_dim, top_k, norm_topk, + infinicore::DataType::F32, device + ); + } + + // Forward + torch::Tensor forward(torch::Tensor hidden_states) { + auto infini_input = torch_to_infini_view(hidden_states); + auto infini_output = moe_block->forward(infini_input); + return infini_to_torch_copy(infini_output); + } + + // 【核心】加载权重接口 + // 接收 Python 的 Dict[str, Tensor] + void load_state_dict(std::map weights) { + std::unordered_map infini_dict; + + for (auto const& [name, tensor] : weights) { + infini_dict.emplace(name, torch_to_infini_view(tensor.contiguous())); + } + moe_block->load_state_dict(infini_dict); + + std::cout << "[C++] load_state_dict finished. Loaded " << infini_dict.size() << " tensors." << std::endl; + } +}; + +// ===================================================================== +// 4. 定义 Python 模块 +// ===================================================================== +PYBIND11_MODULE(gu_moe_ops, m) { + py::class_(m, "GuMoeBlock") + .def(py::init()) + .def("forward", &PyGuMoeWrapper::forward) + .def("load_state_dict", &PyGuMoeWrapper::load_state_dict); +} + diff --git a/moe_gu_ops/setup.py b/moe_gu_ops/setup.py new file mode 100644 index 000000000..cbd04ab5f --- /dev/null +++ b/moe_gu_ops/setup.py @@ -0,0 +1,99 @@ +import os +import sys + +try: + import torch + from torch.utils import cpp_extension + def no_op_check(compiler_name, compiler_version): + return True + cpp_extension._check_cuda_version = no_op_check + os.environ["TORCH_DONT_CHECK_CUDA_VERSION_COMPATIBILITY"] = "1" +except ImportError: + pass +# ================================================================================= + +if hasattr(torch, '_C'): + torch._C._GLIBCXX_USE_CXX11_ABI = True + +from setuptools import setup +from torch.utils.cpp_extension import BuildExtension, CUDAExtension +import pybind11 + +# 获取当前目录 (即 .../InfiniCore/gu_moe_ops) +CURRENT_DIR = os.path.dirname(os.path.abspath(__file__)) + +# ------------------------------------------------------------------------- +# 1. 路径配置 (动态推导) +# ------------------------------------------------------------------------- +# 优先使用环境变量,如果没有,则根据目录结构自动推导 +# 假设结构是: +# InfiniCore/ +# ├── gu_moe_ops/ <-- 我们在这里 +# ├── include/ +# └── build/ +INFINI_SRC_ROOT = os.getenv("INFINICORE_ROOT") +if not INFINI_SRC_ROOT: + # 向上找一级,就是 InfiniCore 根目录 + INFINI_SRC_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../")) + +INFINI_LIB_DIR = os.getenv("INFINICORE_BUILD_DIR") +if not INFINI_LIB_DIR: + # 默认构建路径 + INFINI_LIB_DIR = os.path.join(INFINI_SRC_ROOT, 'build/linux/x86_64/release') + +print(f"[Info] InfiniCore Root: {INFINI_SRC_ROOT}") +print(f"[Info] InfiniCore Lib Dir: {INFINI_LIB_DIR}") + +# 检查一下库目录是否存在,防止后面报错看不懂 +if not os.path.exists(INFINI_LIB_DIR): + print(f"[Warning] Library directory not found: {INFINI_LIB_DIR}") + print("Please check if InfiniCore is built or set INFINICORE_BUILD_DIR env var.") + +# ------------------------------------------------------------------------- +# 2. 
库列表 +# ------------------------------------------------------------------------- +libs = [ + os.path.join(INFINI_LIB_DIR, 'libinfini-utils.a'), + os.path.join(INFINI_LIB_DIR, 'libinfiniop-nvidia.a'), + os.path.join(INFINI_LIB_DIR, 'libinfiniccl-nvidia.a'), + os.path.join(INFINI_LIB_DIR, 'libinfinirt-nvidia.a') +] + +setup( + name='gu_moe_ops', + version='0.1.0', + ext_modules=[ + CUDAExtension( + name='gu_moe_ops', + sources=[ + 'pybind_gumoe.cc', + 'src/gumoe.cpp', + 'src/gu_mul.cc', + 'src/gu_topk_softmax.cc', + 'src/nvidia_kernels/gu_reduce.cu', + 'src/nvidia_kernels/gu_sort.cu', + ], + include_dirs=[ + pybind11.get_include(), + os.path.join(INFINI_SRC_ROOT, 'include'), # 动态引用父目录的 include + os.path.join(CURRENT_DIR, 'src'), + "/usr/local/cuda/include" + ], + extra_objects=libs, + + extra_link_args=[ + '-Wl,--allow-shlib-undefined' + ], + + extra_compile_args={ + 'cxx': ['-O3', '-std=c++17', '-D_GLIBCXX_USE_CXX11_ABI=1'], + 'nvcc': [ + '-O3', '--use_fast_math', + # 注意:移除了硬编码的 -gencode 参数 + # 编译时会自动读取环境变量 TORCH_CUDA_ARCH_LIST + ] + } + ) + ], + cmdclass={ 'build_ext': BuildExtension } +) \ No newline at end of file diff --git a/moe_gu_ops/src/gu_moe.h b/moe_gu_ops/src/gu_moe.h new file mode 100644 index 000000000..42fa43426 --- /dev/null +++ b/moe_gu_ops/src/gu_moe.h @@ -0,0 +1,69 @@ +#ifndef GU_MOE_H +#define GU_MOE_H + +#include +#include +#include +#include "infinicore/tensor.hpp" +#include "infinicore/nn/module.hpp" +#include "infinicore/device.hpp" +#include "infinicore/ops.hpp" + +namespace infinicore::nn { + +// Router +class GuMoeTopkRounter : public Module { +private: + int top_k_; + int num_experts_; + int hidden_dim_; + bool norm_topk_prob_; + Parameter weight_; + infiniopHandle_t handle_ = nullptr; +public: + GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, + const DataType &dtype, const Device &device); + ~GuMoeTopkRounter(); + void set_weight(Tensor w); // 设置路由权重 + std::pair forward(const Tensor &hidden_states) const; +}; + +// Experts +class GuMoeExperts : public Module { +private: + int num_experts_; + int hidden_dim_; + int intermediate_dim_; + Parameter gate_up_proj_; + Parameter down_proj_; + infiniopHandle_t handle_ = nullptr; + Device device_; + +public: + GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, + const DataType& dtype, const Device& device); + ~GuMoeExperts(); + void set_weights(Tensor gate_up, Tensor down); // 设置专家权重 + Tensor forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const; +}; + +// Block (MoE 整体) +class GuMoeSparseMoeBlock : public Module { +private: + std::shared_ptr router_; + std::shared_ptr experts_; + +public: + GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, + int top_k, bool norm_topk, + const DataType& dtype, const Device& device); + + // ✅ 统一设置权重的接口 + void set_weights(Parameter gate_up, Parameter down, Parameter router_weight); + + // ✅ 关键:Forward 只需要 Input (Router 在内部计算) + Tensor forward(const Tensor& hidden_states); +}; + +} // namespace +#endif \ No newline at end of file diff --git a/moe_gu_ops/src/gu_mul.cc b/moe_gu_ops/src/gu_mul.cc new file mode 100644 index 000000000..d7d16f579 --- /dev/null +++ b/moe_gu_ops/src/gu_mul.cc @@ -0,0 +1,49 @@ +#include "infinicore/tensor.hpp" +#include "infinicore/context/context.hpp" +#include "infiniop/ops/mul.h" +#include "gu_mul.h" + +namespace infinicore::op { + +// 【修改点 1】函数签名增加 infiniopHandle_t handle +Tensor mul(Tensor a, Tensor b, infiniopHandle_t handle) { + + Tensor c = 
Tensor::empty(a->shape(), a->dtype(), a->device()); + + infiniopMulDescriptor_t desc; + infiniopCreateMulDescriptor( + handle, + &desc, + c->desc(), + a->desc(), + b->desc() + ); + + size_t workspace_size = 0; + std::shared_ptr workspace_mem = nullptr; + void* workspace = nullptr; + infiniopGetMulWorkspaceSize(desc, &workspace_size); + + if (workspace_size > 0) { + workspace_mem = context::allocateMemory(workspace_size); + workspace = workspace_mem->data(); + } + + void* stream = nullptr; + + infiniopMul( + desc, + workspace, + workspace_size, + c->data(), + a->data(), + b->data(), + stream + ); + + infiniopDestroyMulDescriptor(desc); + + return c; +} + +} // namespace infinicore::op \ No newline at end of file diff --git a/moe_gu_ops/src/gu_mul.h b/moe_gu_ops/src/gu_mul.h new file mode 100644 index 000000000..6dd06cd7a --- /dev/null +++ b/moe_gu_ops/src/gu_mul.h @@ -0,0 +1,14 @@ +#ifndef GU_MUL_H +#define GU_MUL_H + +#include "infinicore/tensor.hpp" +#include "infinicore/context/context.hpp" +#include "infiniop/ops/mul.h" + +namespace infinicore::op { + +Tensor mul(Tensor a, Tensor b, infiniopHandle_t handle); + +} // namespace infinicore::op + +#endif // GU_MUL_H \ No newline at end of file diff --git a/moe_gu_ops/src/gu_topk_softmax.cc b/moe_gu_ops/src/gu_topk_softmax.cc new file mode 100644 index 000000000..14cbe2e73 --- /dev/null +++ b/moe_gu_ops/src/gu_topk_softmax.cc @@ -0,0 +1,152 @@ +// #include "gu_topk_softmax.h" +// #include "infinicore/context/context.hpp" +// // 必须包含底层的 C 接口定义 +// #include "infiniop/ops/topksoftmax.h" + +// namespace infinicore::op { + +// std::pair topk_softmax(Tensor input, int k, bool normalize, infiniopHandle_t handle) { +// // 1. 手动计算输出形状 +// // input: [..., Hidden] -> output: [..., K] +// // 假设 TopK 操作在最后一维进行 +// Shape out_shape = input->shape(); +// out_shape[out_shape.size() - 1] = k; + +// // 2. 创建输出 Tensors (分配物理内存) +// // values (概率值) 类型与 input 一致 +// Tensor values = Tensor::empty(out_shape, input->dtype(), input->device()); +// // indices (索引) 类型通常是 int32 (从源码 (int *)indices 推断) +// Tensor indices = Tensor::empty(out_shape, DataType::I32, input->device()); + +// // 3. 创建描述符 (Create Descriptor) +// // 根据源码: infiniopCreateTopksoftmaxDescriptor(handle, &desc_ptr, x_desc) +// // 只传入了 input 的描述符 +// infiniopTopksoftmaxDescriptor_t desc; +// infiniopCreateTopksoftmaxDescriptor( +// handle, +// &desc, +// input->desc() +// ); + +// // 4. 申请 Workspace +// size_t workspace_size = 0; +// infiniopGetTopksoftmaxWorkspaceSize(desc, &workspace_size); + +// // 使用智能指针管理内存 (RAII) +// std::shared_ptr workspace_mem = nullptr; +// void* workspace = nullptr; + +// if (workspace_size > 0) { +// workspace_mem = context::allocateMemory(workspace_size); +// workspace = workspace_mem->data(); +// } + +// // 5. 执行计算 (Calculate) +// // 根据源码: infiniopTopksoftmax(desc, ws, ws_size, val, idx, x, topk, norm, stream) +// // topk 和 norm 是在这里传入的! +// void* stream = nullptr; +// // stream = context::getStream(); + +// infiniopTopksoftmax( +// desc, +// workspace, +// workspace_size, +// values->data(), // void* values +// indices->data(), // void* indices +// input->data(), // const void* x +// static_cast(k), // const size_t topk +// normalize ? 1 : 0, // const int norm (0 or 1) +// stream +// ); + +// // 6. 
销毁描述符 +// infiniopDestroyTopksoftmaxDescriptor(desc); + +// return {values, indices}; +// } + +// } // namespace infinicore::op + +#include "gu_topk_softmax.h" +#include "infinicore/context/context.hpp" +#include "infiniop/ops/topksoftmax.h" +#include "infinirt.h" // 用于同步 +#include + +namespace infinicore::op { + +// 辅助宏:检查状态,出错则抛异常 +#define CHECK_STATUS(call, msg) \ + do { \ + auto status = (call); \ + if (status != INFINI_STATUS_SUCCESS) { \ + std::string err_msg = std::string("[TopK_Softmax Error] ") + msg + " (Status Code: " + std::to_string(status) + ")"; \ + std::cerr << err_msg << std::endl; \ + throw std::runtime_error(err_msg); \ + } \ + } while (0) + +std::pair topk_softmax(Tensor input, int k, bool normalize, infiniopHandle_t handle) { + Shape out_shape = input->shape(); + out_shape[out_shape.size() - 1] = k; + + // 1. 创建输出张量 + // values: 概率 + Tensor values = Tensor::empty(out_shape, input->dtype(), input->device()); + // indices: 索引 (I32) + Tensor indices = Tensor::empty(out_shape, DataType::I32, input->device()); + + // 2. 创建算子描述符 + infiniopTopksoftmaxDescriptor_t desc; + CHECK_STATUS( + infiniopCreateTopksoftmaxDescriptor(handle, &desc, input->desc()), + "Failed to create descriptor" + ); + + // 3. 申请 Workspace + size_t workspace_size = 0; + CHECK_STATUS( + infiniopGetTopksoftmaxWorkspaceSize(desc, &workspace_size), + "Failed to get workspace size" + ); + + std::shared_ptr workspace_mem = nullptr; + void* workspace = nullptr; + if (workspace_size > 0) { + workspace_mem = context::allocateMemory(workspace_size); + workspace = workspace_mem->data(); + } + + void* stream = nullptr; + + // 4. 执行计算 (Execute) + // 注意:如果这里失败,会直接抛出异常,而不是返回全0 + CHECK_STATUS( + infiniopTopksoftmax( + desc, + workspace, + workspace_size, + values->data(), // Arg 1: Values (Probs) - 按照你的意愿保持原样 + indices->data(), // Arg 2: Indices (Ints) - 按照你的意愿保持原样 + input->data(), + static_cast(k), + normalize ? 1 : 0, + stream + ), + "Kernel execution failed" + ); + + // 5. 销毁描述符 (防止资源泄漏) + CHECK_STATUS( + infiniopDestroyTopksoftmaxDescriptor(desc), + "Failed to destroy descriptor" + ); + + // 6. 【关键】强制同步 + // 防止因为 GPU 还没算完,后续代码就去读,导致读到 0 + infinirtDeviceSynchronize(); + + return {values, indices}; +} + +} // namespace infinicore::op \ No newline at end of file diff --git a/moe_gu_ops/src/gu_topk_softmax.h b/moe_gu_ops/src/gu_topk_softmax.h new file mode 100644 index 000000000..3659c4b48 --- /dev/null +++ b/moe_gu_ops/src/gu_topk_softmax.h @@ -0,0 +1,25 @@ +#ifndef GU_TOPK_SOFTMAX_H +#define GU_TOPK_SOFTMAX_H + +#include "infinicore/tensor.hpp" +#include "infinicore/ops.hpp" +#include "infinicore/nn.hpp" +#include "infinicore/device.hpp" +#include "infinicore/ops/linear.hpp" +#include "infiniop/ops/topksoftmax.h" +#include +#include +#include // for std::pair + +namespace infinicore::op { + +std::pair topk_softmax( + Tensor input, + int k, + bool normalize, + infiniopHandle_t handle +); + +} // namespace infinicore::op + +#endif // GU_TOPK_SOFTMAX_H \ No newline at end of file diff --git a/moe_gu_ops/src/gumoe.cpp b/moe_gu_ops/src/gumoe.cpp new file mode 100644 index 000000000..6731093ef --- /dev/null +++ b/moe_gu_ops/src/gumoe.cpp @@ -0,0 +1,152 @@ +#include "gu_moe.h" +#include +#include +#include +#include +#include +#include +#include + +#include "src/nvidia_kernels/nvidia_kernels_moe.h" +#include "infinicore/ops.hpp" +#include "infinirt.h" +#include "infiniop.h" +#include "gu_mul.h" +#include "gu_topk_softmax.h" + +#define LOG_ERR(fmt, ...) 
fprintf(stderr, "[ERROR] " fmt "\n", ##__VA_ARGS__) +#define CHECK_CUDA(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + LOG_ERR("CUDA Error at line %d: %s", __LINE__, cudaGetErrorString(err)); \ + } \ + } while (0) + +namespace infinicore::nn { + +GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) + : top_k_(top_k), num_experts_(num_experts), hidden_dim_(hidden_dim), norm_topk_prob_(norm_topk_prob) { + infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); + infiniopCreateHandle(&this->handle_); + INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +} +GuMoeTopkRounter::~GuMoeTopkRounter() { if (handle_) infiniopDestroyHandle(handle_); } +std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { + size_t total_tokens = hidden_states->numel() / hidden_dim_; + Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); + Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); + auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); + return {val, idx}; +} + +GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) + : num_experts_(num_experts), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), device_(device) { + infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); + infiniopCreateHandle(&this->handle_); + INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); + INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +} +GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + +Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { + Device device = hidden_states->device(); + cudaStream_t stream = 0; + + size_t num_tokens = hidden_states->numel() / hidden_dim_; + int top_k = top_k_index->shape()[1]; + size_t expanded_size = num_tokens * top_k; + + Tensor indices_i32 = Tensor::empty(top_k_index->shape(), DataType::I32, device); + size_t num_elements = top_k_index->numel(); + + // 简单的启发式检查:如果 Dtype 是 5, 6, 7(I64), 0(F32) + int type_id = (int)top_k_index->dtype(); + + if (type_id == 7) { + Tensor cpu_indices = top_k_index->to(Device(Device::Type::CPU)); + std::vector vec_i32(num_elements); + const int64_t* ptr = (const int64_t*)cpu_indices->data(); + for(size_t i=0; i(ptr[i]); + CHECK_CUDA(cudaMemcpyAsync(indices_i32->data(), vec_i32.data(), vec_i32.size() * sizeof(int32_t), cudaMemcpyHostToDevice, stream)); + } + else { + CHECK_CUDA(cudaMemcpyAsync(indices_i32->data(), top_k_index->data(), num_elements * sizeof(int32_t), cudaMemcpyDeviceToDevice, stream)); + } + + Tensor expert_counts = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); + Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); + Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); + Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); + Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); + Tensor 
sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); + Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + + // 2. 排序 + launch_moe_sort( + (int32_t*)indices_i32->data(), + (int32_t*)expert_counts->data(), + (int32_t*)expert_offsets->data(), + (int)num_tokens, top_k, num_experts_, stream + ); + + launch_moe_permute( + (float*)hidden_states->data(), + (int32_t*)indices_i32->data(), + (float*)top_k_values->data(), + (int32_t*)expert_offsets->data(), + (float*)sorted_input->data(), + (int32_t*)sorted_row_map->data(), + (float*)sorted_weights->data(), + (int32_t*)expert_counts->data(), + (int)num_tokens, top_k, hidden_dim_, num_experts_, stream + ); + + // 3. 拷回 Offsets 供循环使用 + std::vector h_offsets(num_experts_ + 1); + CHECK_CUDA(cudaMemcpyAsync(h_offsets.data(), expert_offsets->data(), sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost, stream)); + CHECK_CUDA(cudaStreamSynchronize(stream)); // 必须同步等 offsets 回来 + + // 4. 计算 Loop + for (int e = 0; e < num_experts_; ++e) { + int start_idx = h_offsets[e]; + int count = h_offsets[e+1] - start_idx; + + if (count <= 0) continue; + + { + Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); + Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); + Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + + Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); + Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); + Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + + Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); + Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + + CHECK_CUDA(cudaMemcpyAsync((float*)sorted_output->data() + start_idx * hidden_dim_, (float*)expert_res->data(), (size_t)count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream)); + } + } + + launch_moe_reduce((float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), (float*)final_output->data(), (int)num_tokens, top_k, hidden_dim_, stream); + + // 如果是最后一步,通常不需要显式同步,除非后续逻辑需要 + // cudaStreamSynchronize(stream); + return final_output; +} + +GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk, const DataType& dtype, const Device& device) { + router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); + experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +} +Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { + size_t total_tokens = hidden_states->numel() / (hidden_states->shape().back()); + Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_states->shape().back()}); + auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); + Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); + return final_hidden_states->view(hidden_states->shape()); +} + +} // namespace infinicore::nn \ No newline at end of file diff --git a/moe_gu_ops/src/gumoe_wrapper.py b/moe_gu_ops/src/gumoe_wrapper.py new file mode 100644 index 000000000..30a30d44f --- /dev/null +++ b/moe_gu_ops/src/gumoe_wrapper.py @@ -0,0 +1,65 @@ +import torch 
+import numpy as np +import gu_moe_ops # 编译出的库 + +class TensorWrapper: + """把 Torch Tensor 包装成 C++ 能读懂的简单对象""" + def __init__(self, tensor): + if not tensor.is_contiguous(): + tensor = tensor.contiguous() + self._t = tensor # 保持引用,防止被回收 + + self.ptr = tensor.data_ptr() + self.shape = list(tensor.shape) + self.device_id = tensor.device.index + + # 简单类型映射 + if tensor.dtype == torch.float32: self.dtype_id = 0 + elif tensor.dtype == torch.bfloat16: self.dtype_id = 1 + elif tensor.dtype == torch.int32: self.dtype_id = 2 + else: raise ValueError(f"Unsupported dtype: {tensor.dtype}") + +class GuMoeBlock(torch.nn.Module): + def __init__(self, num_experts, hidden_size, intermediate_size, top_k): + super().__init__() + # 初始化 C++ 对象 + self.cpp_block = gu_moe_ops.GuMoeBlock( + num_experts, + hidden_size, + intermediate_size, + 0, # dtype=F32 + torch.cuda.current_device() + ) + self.top_k = top_k + + def load_weights(self, state_dict): + # 1. 过滤并重命名权重 + clean_weights = {} + # 假设原始 key 是 "model.layers.0.moe..." + # 我们需要在 Python 里做映射,变成 "moe.gate_up_proj" + # 这里仅作演示,具体映射逻辑看你的模型结构 + + for k, v in state_dict.items(): + # 必须转为 Numpy 且连续 + arr = v.cpu().float().numpy() + if not arr.flags['C_CONTIGUOUS']: + arr = np.ascontiguousarray(arr) + clean_weights[k] = arr + + # 2. 调用 C++ Loader + self.cpp_block.load_weights(clean_weights) + + def forward(self, hidden_states, top_k_indices, top_k_values): + # 1. 包装输入 + input_w = TensorWrapper(hidden_states) + idx_w = TensorWrapper(top_k_indices) + val_w = TensorWrapper(top_k_values) + + # 2. 预分配输出 (Output Buffer) + output = torch.empty_like(hidden_states) + output_w = TensorWrapper(output) + + # 3. 调用 C++ + self.cpp_block.forward(input_w, idx_w, val_w, output_w) + + return output \ No newline at end of file diff --git a/moe_gu_ops/src/history.cc b/moe_gu_ops/src/history.cc new file mode 100644 index 000000000..9e43a9912 --- /dev/null +++ b/moe_gu_ops/src/history.cc @@ -0,0 +1,2303 @@ +GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) + +: num_experts_(num_experts), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), device_(device) { + +infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); + +infiniopCreateHandle(&this->handle_); + +INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); + +INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); + +} + +GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + + + +Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { + +if (hidden_states->dtype() != DataType::F32) throw std::runtime_error("F32 only"); + + + +Device gpu = hidden_states->device(); + +Device cpu(Device::Type::CPU); + + + +Tensor cpu_indices = top_k_index->to(cpu); + +Tensor cpu_values = top_k_values->to(cpu); + +Tensor cpu_hidden = hidden_states->to(cpu); + + +Tensor final_cpu_states = Tensor::zeros(hidden_states->shape(), hidden_states->dtype(), cpu); + +std::memset(final_cpu_states->data(), 0, final_cpu_states->numel() * sizeof(float)); + + + +size_t total_tokens = hidden_states->numel() / hidden_dim_; + +int top_k = top_k_index->shape()[1]; + + + +struct Task { int token_idx; int rank_idx; }; + +std::vector> buckets(num_experts_); + +const void* raw_idx = cpu_indices->data(); 
+ + +bool is_i32 = (cpu_indices->dtype() == DataType::I32); + +const float* all_vals = (const float*)cpu_values->data(); + + + +// 路由信息打印 (保持你之前的) + +static bool debug_printed = false; + +if (!debug_printed) { + +std::cout << "\n[C++ Debug Info]" << std::endl; + +std::cout << "Token 0 Selected Experts: ["; + +for(int k=0; k(top_k); ++k) { + +int64_t val = is_i32 ? (int64_t)((const int32_t*)raw_idx)[i*top_k+k] : ((const int64_t*)raw_idx)[i*top_k+k]; + +int eid = (int)val; + +if (eid >= 0 && eid < num_experts_) buckets[eid].push_back({(int)i, (int)k}); + +} + +} + + + +infinirtSetDevice((infiniDevice_t)device_.getType(), device_.getIndex()); + + + +for (int e = 0; e < num_experts_; ++e) { + +if (buckets[e].empty()) continue; + +size_t n = buckets[e].size(); + + + +std::vector t_idx(n); + +std::vector t_w(n); + + +// 查找 Token 0 在当前 bucket 中的位置 + +int local_token0_idx = -1; + + + +for(size_t i=0; idtype(), cpu); + +cpu_gather((float*)cpu_in->data(), (const float*)cpu_hidden->data(), t_idx, hidden_dim_); + +Tensor gpu_in = cpu_in->to(gpu); + + + +Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); + +Tensor gate_up_out = infinicore::op::linear(gpu_in, w_gate_up, std::nullopt); + + + +// ===================================================================== + +// 🕵️‍♂️ [新增] FFN 中间值探针 (Gate/Up) + +// ===================================================================== + +static bool ffn_debug_printed = false; + +if (!ffn_debug_printed && local_token0_idx != -1) { + +std::cout << "\n[C++ FFN Internal Debug] Expert " << e << " processing Token 0" << std::endl; + + +// 拷回 CPU + +Tensor debug_tensor = gate_up_out->to(cpu); + +const float* ptr = (const float*)debug_tensor->data(); + + +// 定位到 Token 0 的那一行数据 + +// shape: [n, 2 * intermediate_dim] + +size_t row_offset = local_token0_idx * (2 * intermediate_dim_); + +const float* token0_row = ptr + row_offset; + + + +// 打印前半部分 (C++ 认为是 Gate) + +std::cout << " C++ First Half (Gate?): ["; + +for(int j=0; j<5; ++j) std::cout << token0_row[j] << ", "; + +std::cout << "...]" << std::endl; + + + +// 打印后半部分 (C++ 认为是 Up) + +size_t mid = intermediate_dim_; + +std::cout << " C++ Second Half (Up?): ["; + +for(int j=0; j<5; ++j) std::cout << token0_row[mid+j] << ", "; + +std::cout << "...]" << std::endl; + + +ffn_debug_printed = true; + +} + +// ===================================================================== + + + +Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); + +Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); + + +Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +Tensor gpu_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + + +Tensor cpu_res = gpu_res->to(cpu); + + + +infinirtDeviceSynchronize(); + + + +cpu_index_add_scale((float*)final_cpu_states->data(), + +(const float*)cpu_res->data(), + +t_idx, t_w, hidden_dim_, + +total_tokens); + +} + + +if (!debug_printed) debug_printed = true; // 防止漏打导致多次 + + + +return final_cpu_states->to(gpu); + +} + +// #include "gu_moe.h" + +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinicore/ops/linear.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include 
"gu_mul.h" +// #include "gu_topk_softmax.h" + +// namespace infinicore::nn { + +// namespace { + +// void debug_tensor(const std::string& name, const Tensor& t, int count=5) { +// Device cpu(Device::Type::CPU); +// Tensor c = t->to(cpu); +// if (c->dtype() == DataType::F32) { +// const float* ptr = reinterpret_cast(c->data()); +// float min_v = 1e30, max_v = -1e30; +// double sum = 0; +// for(size_t i=0; inumel(); ++i) { +// float v = ptr[i]; +// if(v < min_v) min_v = v; +// if(v > max_v) max_v = v; +// sum += std::abs(v); +// } +// std::cout << "[DEBUG] " << name << " | Min: " << min_v << " | Max: " << max_v +// << " | MeanAbs: " << (sum / c->numel()) << std::endl; +// } +// } + +// void cpu_gather(float* dest, const float* src, const std::vector& indices, int hidden_dim) { +// for (size_t i = 0; i < indices.size(); ++i) { +// int row = indices[i]; +// std::memcpy(dest + i * hidden_dim, src + row * hidden_dim, hidden_dim * sizeof(float)); +// } +// } + +// void cpu_index_add_scale(float* dest, const float* src, +// const std::vector& indices, +// const std::vector& weights, +// int hidden_dim, +// size_t total_rows) { +// for (size_t i = 0; i < indices.size(); ++i) { +// int row = indices[i]; +// if (row < 0 || row >= (int)total_rows) continue; +// float w = weights[i]; +// float* d_row = dest + row * hidden_dim; +// const float* s_row = src + i * hidden_dim; +// for (int j = 0; j < hidden_dim; ++j) { +// d_row[j] += s_row[j] * w; +// } +// } +// } + +// } // namespace anonymous + +// // ... GuMoeTopkRounter ... +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), num_experts_(num_experts), hidden_dim_(hidden_dim), norm_topk_prob_(norm_topk_prob) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } +// GuMoeTopkRounter::~GuMoeTopkRounter() { if (handle_) infiniopDestroyHandle(handle_); } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // ... GuMoeExperts ... 
+// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), device_(device) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } +// GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// if (hidden_states->dtype() != DataType::F32) throw std::runtime_error("F32 only"); + +// // 0. 上下文准备 +// Device device = hidden_states->device(); +// // 假设使用默认流 0。如果 infiniop 支持获取流,建议使用 context::getStream() +// cudaStream_t stream = 0; + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// // 1. 分配 GPU 显存 (Workspace) +// // 工业级优化点:这里的 Tensor::zeros/empty 每次 forward 都会申请显存。 +// // 如果追求极致性能,建议在类里维护一个缓存池 (Tensor workspace_)。 + +// // 计数器和偏移量 +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); + +// // 中间 buffer (排序后的输入/输出) +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); + +// // 辅助信息 (Row Map 和 Weights) +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); + +// // 最终输出 (必须初始化为 0,因为 Reduce 是累加) +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// // 获取裸指针 +// float* d_input = (float*)hidden_states->data(); +// int32_t* d_indices = (int32_t*)top_k_index->data(); +// float* d_values = (float*)top_k_values->data(); + +// int32_t* d_counts = (int32_t*)expert_counts->data(); +// int32_t* d_offsets = (int32_t*)expert_offsets->data(); + +// // ====================================================================== +// // Phase 1: 数据重排 (GPU Sort & Permute) +// // 彻底取代原来的 CPU bucket 和 cpu_gather +// // ====================================================================== + +// // 1.1 排序:计算每个专家的 Token 数量和偏移量 +// launch_moe_sort( +// d_indices, d_counts, d_offsets, +// num_tokens, top_k, num_experts_, +// stream +// ); + +// // 1.2 搬运:将 Input 和 Weights 按照专家顺序连续排列到 sorted_input/sorted_weights +// // 注意:复用 expert_counts 作为 running_counters (内部会自动清零) +// launch_moe_permute( +// d_input, +// d_indices, +// d_values, +// d_offsets, +// (float*)sorted_input->data(), +// (int32_t*)sorted_row_map->data(), +// (float*)sorted_weights->data(), +// d_counts, +// num_tokens, top_k, hidden_dim_, num_experts_, +// stream +// ); + +// // ====================================================================== +// // Phase 2: 计算 (GPU Loop) +// // 这里的循环仅用于发射 Kernel,数据全程在 GPU 上,没有拷贝开销 +// // ====================================================================== + +// // 将 Offsets 拷回 CPU,以便 CPU 知道如何对 sorted_input 进行切片 
+// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpyAsync(h_offsets.data(), d_offsets, sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost, stream); +// cudaStreamSynchronize(stream); // 等待 Offset 拷贝完成 + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; + +// // 如果该专家没有分配到 Token,跳过 +// if (count == 0) continue; + +// // A. 切片 (Slicing - Zero Copy) +// // 这里的 narrow 只是创建 View,不发生数据搬运 +// // 切出属于当前专家的输入数据 +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); + +// // 切出当前专家的权重 +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// // B. 计算 (Computation - All on GPU) +// // 1. Linear: Input * GateUp +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); + +// // 2. Activation: SiLU(Gate) * Up +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// // FFN Inner +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); + +// // 3. Linear: Inner * Down +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// // C. 写回大 Buffer (Scatter back to sorted_output) +// // infiniop::linear 返回的是新分配的 Tensor,我们需要把它拷贝回 sorted_output 的对应位置 +// // 这一步是 Device-to-Device Copy,速度极快 + +// float* dst_ptr = (float*)sorted_output->data() + start_idx * hidden_dim_; +// const float* src_ptr = (const float*)expert_res->data(); +// size_t bytes = count * hidden_dim_ * sizeof(float); + +// cudaMemcpyAsync(dst_ptr, src_ptr, bytes, cudaMemcpyDeviceToDevice, stream); +// } + +// // ====================================================================== +// // Phase 3: 还原 (GPU Reduce) +// // 使用 sorted_row_map 和 sorted_weights 将结果加权累加回 final_output +// // ====================================================================== + +// launch_moe_reduce( +// (float*)sorted_output->data(), +// (int32_t*)sorted_row_map->data(), +// (float*)sorted_weights->data(), +// (float*)final_output->data(), +// num_tokens, top_k, hidden_dim_, +// stream +// ); + +// return final_output; +// } + +// // ... 保持不变 ... 
+// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, +// int top_k, bool norm_topk, +// const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// auto input_shape = hidden_states->shape(); +// size_t batch_size = input_shape[0]; +// size_t seq_len = input_shape[1]; +// size_t hidden_dim = input_shape[2]; +// size_t total_tokens = hidden_states->numel() / hidden_dim; +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_dim}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view({batch_size, seq_len, hidden_dim}); +// } + +// } // namespace + +// #include "gu_moe.h" + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include // 补充: 为了 std::get, std::tuple + +// // 确保包含你项目中实际存在的头文件 +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// // #include "infinicore/ops/linear.hpp" // 如果 ops.hpp 已包含,可注释 +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// // #include "gu_mul.h" // 如果不需要可注释 +// #include "gu_topk_softmax.h" // 确保这个文件存在 + +// namespace infinicore::nn { + +// namespace { + +// void debug_tensor(const std::string& name, const Tensor& t, int count=5) { +// Device cpu(Device::Type::CPU); +// Tensor c = t->to(cpu); +// if (c->dtype() == DataType::F32) { +// const float* ptr = reinterpret_cast(c->data()); +// float min_v = 1e30, max_v = -1e30; +// double sum = 0; +// for(size_t i=0; inumel(); ++i) { +// float v = ptr[i]; +// if(v < min_v) min_v = v; +// if(v > max_v) max_v = v; +// sum += std::abs(v); +// } +// std::cout << "[DEBUG] " << name << " | Min: " << min_v << " | Max: " << max_v +// << " | MeanAbs: " << (sum / c->numel()) << std::endl; +// } +// } + +// } // namespace anonymous + +// // ========================================== +// // GuMoeTopkRounter 实现 +// // ========================================== + +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), +// num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// norm_topk_prob_(norm_topk_prob) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// // 初始化权重,假设宏 INFINICORE_NN_PARAMETER_INIT 会处理赋值 +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } + +// GuMoeTopkRounter::~GuMoeTopkRounter() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); + +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); + +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); + +// return {val, idx}; +// } + +// // ========================================== +// // 
GuMoeExperts 实现 +// // ========================================== + +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// intermediate_dim_(intermediate_dim), +// device_(device) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } + +// GuMoeExperts::~GuMoeExperts() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// if (hidden_states->dtype() != DataType::F32) throw std::runtime_error("F32 only"); + +// // 0. 上下文准备 +// Device device = hidden_states->device(); +// cudaStream_t stream = 0; // 默认流 + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// // 假设 top_k_index shape 是 [num_tokens, top_k] +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// // 1. 分配 GPU 显存 (Workspace) +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); + +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); + +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); + +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// // 获取裸指针 +// float* d_input = (float*)hidden_states->data(); +// int32_t* d_indices = (int32_t*)top_k_index->data(); +// float* d_values = (float*)top_k_values->data(); + +// int32_t* d_counts = (int32_t*)expert_counts->data(); +// int32_t* d_offsets = (int32_t*)expert_offsets->data(); + +// // ====================================================================== +// // Phase 1: 数据重排 (GPU Sort & Permute) +// // ====================================================================== + +// launch_moe_sort( +// d_indices, d_counts, d_offsets, +// num_tokens, top_k, num_experts_, +// stream +// ); + +// launch_moe_permute( +// d_input, +// d_indices, +// d_values, +// d_offsets, +// (float*)sorted_input->data(), +// (int32_t*)sorted_row_map->data(), +// (float*)sorted_weights->data(), +// d_counts, +// num_tokens, top_k, hidden_dim_, num_experts_, +// stream +// ); + +// // ====================================================================== +// // Phase 2: 计算 (GPU Loop) +// // ====================================================================== + +// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpyAsync(h_offsets.data(), d_offsets, sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost, stream); +// cudaStreamSynchronize(stream); + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; + +// if (count == 0) continue; + +// // A. 
切片 (如果 InfiniCore 确实支持 narrow,这里就没问题) +// // 注意:之前报错说没 narrow,这里保留你的代码。如果再次报错,说明 InfiniCore 只有 slice +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); + +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// // B. 计算 +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); + +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// // FFN Inner +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); + +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// // C. 写回 +// float* dst_ptr = (float*)sorted_output->data() + start_idx * hidden_dim_; +// const float* src_ptr = (const float*)expert_res->data(); +// size_t bytes = count * hidden_dim_ * sizeof(float); + +// cudaMemcpyAsync(dst_ptr, src_ptr, bytes, cudaMemcpyDeviceToDevice, stream); +// } + +// // ====================================================================== +// // Phase 3: 还原 (GPU Reduce) +// // ====================================================================== + +// launch_moe_reduce( +// (float*)sorted_output->data(), +// (int32_t*)sorted_row_map->data(), +// (float*)sorted_weights->data(), +// (float*)final_output->data(), +// num_tokens, top_k, hidden_dim_, +// stream +// ); + +// return final_output; +// } + +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, +// int top_k, bool norm_topk, +// const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// auto input_shape = hidden_states->shape(); +// size_t batch_size = input_shape[0]; +// size_t seq_len = input_shape[1]; +// size_t hidden_dim = input_shape[2]; +// size_t total_tokens = hidden_states->numel() / hidden_dim; +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_dim}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view({batch_size, seq_len, hidden_dim}); +// } + +// } // namespace nn + +// #include "gu_moe.h" + +// #include +// #include +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// #include "gu_topk_softmax.h" + +// namespace infinicore::nn { + +// namespace { + +// void debug_tensor(const std::string& name, const Tensor& t, int count=5) { +// Device cpu(Device::Type::CPU); +// Tensor c = t->to(cpu); +// if (c->dtype() == DataType::F32) { +// const float* ptr = reinterpret_cast(c->data()); +// float min_v = 1e30, max_v = -1e30; +// double sum = 0; +// for(size_t i=0; inumel(); ++i) { +// float v = ptr[i]; +// if(v < min_v) min_v = v; +// if(v > max_v) max_v = v; +// sum += std::abs(v); +// } +// 
std::cout << "[DEBUG] " << name << " | Min: " << min_v << " | Max: " << max_v +// << " | MeanAbs: " << (sum / c->numel()) << std::endl; +// } +// } + +// } // namespace + +// // ========================================== +// // GuMoeTopkRounter 实现 +// // ========================================== + +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), +// num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// norm_topk_prob_(norm_topk_prob) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } + +// GuMoeTopkRounter::~GuMoeTopkRounter() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // ========================================== +// // GuMoeExperts 实现 +// // ========================================== + +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// intermediate_dim_(intermediate_dim), +// device_(device) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } + +// GuMoeExperts::~GuMoeExperts() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// if (hidden_states->dtype() != DataType::F32) throw std::runtime_error("F32 only"); + +// Device device = hidden_states->device(); +// cudaStream_t stream = 0; + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// auto print_shape = [](const Shape& s) { +// std::string out = "["; +// for(size_t i=0; ishape(), sizeof(float)); +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// float* d_input = (float*)hidden_states->data(); +// int32_t* d_indices = (int32_t*)top_k_index->data(); +// float* d_values = (float*)top_k_values->data(); +// int32_t* d_counts = (int32_t*)expert_counts->data(); +// int32_t* d_offsets = (int32_t*)expert_offsets->data(); + +// // Phase 1: 排序与重排 (增加检查点) +// std::cout << "[CHECKPOINT] Launching moe_sort..." << std::endl; +// launch_moe_sort(d_indices, d_counts, d_offsets, num_tokens, top_k, num_experts_, stream); + +// std::cout << "[CHECKPOINT] Launching moe_permute..." 
<< std::endl; +// launch_moe_permute( +// d_input, d_indices, d_values, d_offsets, +// (float*)sorted_input->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// d_counts, num_tokens, top_k, hidden_dim_, num_experts_, stream +// ); + +// // Phase 2: 计算 +// std::vector h_offsets(num_experts_ + 1); +// std::cout << "[CHECKPOINT] Copying offsets to host..." << std::endl; +// // 使用同步拷贝确保安全性 +// cudaMemcpy(h_offsets.data(), d_offsets, sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost); + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; + +// // 增加数据完整性校验,防止由于 Kernel 错误导致的非法内存申请 +// if (count < 0 || count > (int)expanded_size) { +// std::cerr << "[FATAL] Expert " << e << " has invalid token count: " << count << std::endl; +// continue; +// } +// if (count == 0) continue; + +// if (e % 20 == 0) std::cout << "[CHECKPOINT] Expert loop at " << e << ", count=" << count << std::endl; + +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// float* dst_ptr = (float*)sorted_output->data() + start_idx * hidden_dim_; +// cudaMemcpyAsync(dst_ptr, (float*)expert_res->data(), count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream); +// } + +// // Phase 3: 还原 +// std::cout << "[CHECKPOINT] Launching moe_reduce..." 
<< std::endl; +// launch_moe_reduce( +// (float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// (float*)final_output->data(), num_tokens, top_k, hidden_dim_, stream +// ); + +// return final_output; +// } + +// // ========================================== +// // GuMoeSparseMoeBlock 实现 +// // ========================================== + +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, +// int top_k, bool norm_topk, +// const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } + +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// auto input_shape = hidden_states->shape(); +// size_t batch_size = input_shape[0]; +// size_t seq_len = input_shape[1]; +// size_t hidden_dim = input_shape[2]; +// size_t total_tokens = hidden_states->numel() / hidden_dim; +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_dim}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view({batch_size, seq_len, hidden_dim}); +// } + +// } // namespace infinicore::nn + +// #include "gu_moe.h" + +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// #include "gu_topk_softmax.h" + +// namespace infinicore::nn { + +// // ========================================== +// // GuMoeTopkRounter 实现 +// // ========================================== + +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), +// num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// norm_topk_prob_(norm_topk_prob) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } + +// GuMoeTopkRounter::~GuMoeTopkRounter() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // ========================================== +// // GuMoeExperts 实现 +// // ========================================== + +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// intermediate_dim_(intermediate_dim), +// device_(device) +// { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ 
{static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } + +// GuMoeExperts::~GuMoeExperts() { +// if (handle_) infiniopDestroyHandle(handle_); +// } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// if (hidden_states->dtype() != DataType::F32) throw std::runtime_error("F32 only"); + +// Device device = hidden_states->device(); +// cudaStream_t stream = 0; + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// // 1. 分配 Workspace (这些是持久的) +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// float* d_input = (float*)hidden_states->data(); +// int32_t* d_indices = (int32_t*)top_k_index->data(); +// float* d_values = (float*)top_k_values->data(); +// int32_t* d_counts = (int32_t*)expert_counts->data(); +// int32_t* d_offsets = (int32_t*)expert_offsets->data(); + +// launch_moe_sort(d_indices, d_counts, d_offsets, num_tokens, top_k, num_experts_, stream); +// launch_moe_permute( +// d_input, d_indices, d_values, d_offsets, +// (float*)sorted_input->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// d_counts, num_tokens, top_k, hidden_dim_, num_experts_, stream +// ); + +// // 2. 拷贝 Offset 必须同步,否则后面循环会乱 +// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpy(h_offsets.data(), d_offsets, sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost); + +// // 3. 
专家循环:使用大括号控制局部变量生命周期 +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; +// if (count <= 0) continue; + +// { +// // 在这个大括号内定义的 Tensor 会在每一轮迭代结束时立即析构 +// // 这能强制让 cudaMallocAsync 知道这块内存可以回收了 +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// float* dst_ptr = (float*)sorted_output->data() + start_idx * hidden_dim_; +// cudaMemcpyAsync(dst_ptr, (float*)expert_res->data(), count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream); +// } // <--- 关键:在这里,上一轮的所有中间 Tensor 都会被释放 +// } + +// launch_moe_reduce( +// (float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// (float*)final_output->data(), num_tokens, top_k, hidden_dim_, stream +// ); + +// // 4. 最终同步:解决全零问题的关键 +// cudaStreamSynchronize(stream); + +// return final_output; +// } + +// // ========================================== +// // GuMoeSparseMoeBlock 实现 +// // ========================================== + +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, +// int top_k, bool norm_topk, +// const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } + +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// auto input_shape = hidden_states->shape(); +// size_t batch_size = input_shape[0]; +// size_t seq_len = input_shape[1]; +// size_t hidden_dim = input_shape[2]; +// size_t total_tokens = hidden_states->numel() / hidden_dim; +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_dim}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view({batch_size, seq_len, hidden_dim}); +// } + +//} // namespace infinicore::nn + +// #include "gu_moe.h" +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// #include "gu_topk_softmax.h" + +// // 尝试引入框架的流获取接口 +// namespace infinicore::context { +// extern void* getStream(); +// } + +// namespace infinicore::nn { + +// // ========================================== +// // GuMoeTopkRounter +// // ========================================== +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), num_experts_(num_experts), 
hidden_dim_(hidden_dim), norm_topk_prob_(norm_topk_prob) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } + +// GuMoeTopkRounter::~GuMoeTopkRounter() { if (handle_) infiniopDestroyHandle(handle_); } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // ========================================== +// // GuMoeExperts +// // ========================================== +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), device_(device) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } + +// GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// Device device = hidden_states->device(); +// // 使用框架流,如果没有则退回到默认流 0 +// void* raw_stream = infinicore::context::getStream(); +// cudaStream_t stream = (cudaStream_t)raw_stream; //? 
(cudaStream_t)raw_stream : (cudaStream_t)0; + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// // 分配 Workspace +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// // Phase 1 +// launch_moe_sort((int32_t*)top_k_index->data(), (int32_t*)expert_counts->data(), (int32_t*)expert_offsets->data(), num_tokens, top_k, num_experts_, stream); +// launch_moe_permute((float*)hidden_states->data(), (int32_t*)top_k_index->data(), (float*)top_k_values->data(), (int32_t*)expert_offsets->data(), +// (float*)sorted_input->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// (int32_t*)expert_counts->data(), num_tokens, top_k, hidden_dim_, num_experts_, stream); + +// // Phase 2 +// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpy(h_offsets.data(), expert_offsets->data(), sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost); + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; +// if (count <= 0) continue; + +// { // 局部作用域回收显存 +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// cudaMemcpyAsync((float*)sorted_output->data() + start_idx * hidden_dim_, (float*)expert_res->data(), count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream); +// } +// } + +// // Phase 3 +// launch_moe_reduce((float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), (float*)final_output->data(), num_tokens, top_k, hidden_dim_, stream); + +// cudaStreamSynchronize(stream); +// return final_output; +// } + +// // ========================================== +// // GuMoeSparseMoeBlock +// // ========================================== +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk, const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } + +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// 
size_t total_tokens = hidden_states->numel() / (hidden_states->shape().back()); +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_states->shape().back()}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view(hidden_states->shape()); +// } + +// } // namespace infinicore::nn + +// #include "gu_moe.h" +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// #include "gu_topk_softmax.h" + +// namespace infinicore::nn { + +// // ========================================== +// // GuMoeTopkRounter +// // ========================================== +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), num_experts_(num_experts), hidden_dim_(hidden_dim), norm_topk_prob_(norm_topk_prob) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } + +// GuMoeTopkRounter::~GuMoeTopkRounter() { if (handle_) infiniopDestroyHandle(handle_); } + +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // ========================================== +// // GuMoeExperts +// // ========================================== +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), +// hidden_dim_(hidden_dim), +// intermediate_dim_(intermediate_dim), +// device_(device) +// { +// // --- 增加这一段强力打印 --- +// printf("\n[CONSTRUCTOR_DEBUG] num_experts: %d, hidden: %d, inter: %d\n", +// num_experts, hidden_dim, intermediate_dim); +// fflush(stdout); + +// if (num_experts <= 0 || hidden_dim <= 0 || intermediate_dim <= 0) { +// printf("[FATAL] Invalid dimensions detected!\n"); +// fflush(stdout); +// } +// // ------------------------- + +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); + +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } + +// GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// Device device = hidden_states->device(); +// cudaStream_t stream = 0; + +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t 
expanded_size = (size_t)num_tokens * top_k; + +// // 1. 显式分配 Workspace +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// // Phase 1: 数据重排 +// launch_moe_sort((int32_t*)top_k_index->data(), (int32_t*)expert_counts->data(), (int32_t*)expert_offsets->data(), (int)num_tokens, top_k, num_experts_, stream); +// launch_moe_permute((float*)hidden_states->data(), (int32_t*)top_k_index->data(), (float*)top_k_values->data(), (int32_t*)expert_offsets->data(), +// (float*)sorted_input->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), +// (int32_t*)expert_counts->data(), (int)num_tokens, top_k, hidden_dim_, num_experts_, stream); + +// // Phase 2: 计算循环 +// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpy(h_offsets.data(), expert_offsets->data(), sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost); + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; +// if (count <= 0) continue; + +// { // 利用作用域自动析构临时 Tensor,释放显存池 +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// // 执行 FFN +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// Tensor activated_gate = infinicore::op::silu(gate); +// Tensor ffn_inner = infinicore::op::mul(activated_gate, up, this->handle_); +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// cudaMemcpyAsync((float*)sorted_output->data() + start_idx * hidden_dim_, (float*)expert_res->data(), (size_t)count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream); +// } // 此处局部 Tensor 自动析构 +// } + +// // Phase 3: 结果规约 +// launch_moe_reduce((float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), (float*)final_output->data(), (int)num_tokens, top_k, hidden_dim_, stream); + +// cudaStreamSynchronize(stream); +// return final_output; +// } + +// // ========================================== +// // GuMoeSparseMoeBlock +// // ========================================== +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk, const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } + +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// auto shp = 
hidden_states->shape(); +// size_t last_dim = shp.back(); +// size_t total_tokens = hidden_states->numel() / last_dim; + +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, last_dim}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); + +// return final_hidden_states->view(shp); +// } + +// } // namespace infinicore::nn + +// #include "gu_moe.h" +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "src/nvidia_kernels/nvidia_kernels_moe.h" +// #include "infinicore/ops.hpp" +// #include "infinirt.h" +// #include "infiniop.h" +// #include "gu_mul.h" +// #include "gu_topk_softmax.h" + +// // 引入框架流接口 +// namespace infinicore::context { +// extern void* getStream(); +// } + +// namespace infinicore::nn { + +// // GuMoeTopkRounter (保持不变) +// GuMoeTopkRounter::GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob, const DataType &dtype, const Device &device) +// : top_k_(top_k), num_experts_(num_experts), hidden_dim_(hidden_dim), norm_topk_prob_(norm_topk_prob) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(weight, ({ {static_cast(num_experts_), static_cast(hidden_dim_)}, dtype, device })); +// } +// GuMoeTopkRounter::~GuMoeTopkRounter() { if (handle_) infiniopDestroyHandle(handle_); } +// std::pair GuMoeTopkRounter::forward(const Tensor &hidden_states) const { +// size_t total_tokens = hidden_states->numel() / hidden_dim_; +// Tensor flattened = hidden_states->view({total_tokens, static_cast(hidden_dim_)}); +// Tensor logits = infinicore::op::linear(flattened, weight_, std::nullopt); +// auto [val, idx] = infinicore::op::topk_softmax(logits, top_k_, norm_topk_prob_, this->handle_); +// return {val, idx}; +// } + +// // GuMoeExperts (保持不变) +// GuMoeExperts::GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim, const DataType& dtype, const Device& device) +// : num_experts_(num_experts), hidden_dim_(hidden_dim), intermediate_dim_(intermediate_dim), device_(device) { +// infinirtSetDevice((infiniDevice_t)device.getType(), device.getIndex()); +// infiniopCreateHandle(&this->handle_); +// INFINICORE_NN_PARAMETER_INIT(gate_up_proj, ({ {static_cast(num_experts), static_cast(2 * intermediate_dim), static_cast(hidden_dim)}, dtype, device })); +// INFINICORE_NN_PARAMETER_INIT(down_proj, ({ {static_cast(num_experts), static_cast(hidden_dim), static_cast(intermediate_dim)}, dtype, device })); +// } +// GuMoeExperts::~GuMoeExperts() { if (handle_) infiniopDestroyHandle(handle_); } + +// Tensor GuMoeExperts::forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const { +// Device device = hidden_states->device(); +// void* raw_stream = infinicore::context::getStream(); +// cudaStream_t stream = raw_stream ? 
(cudaStream_t)raw_stream : 0; + +// // 回退类型转换,直接使用原始指针,但在 count 计算处做防御 +// size_t num_tokens = hidden_states->numel() / hidden_dim_; +// int top_k = top_k_index->shape()[1]; +// size_t expanded_size = num_tokens * top_k; + +// Tensor expert_counts = Tensor::zeros({(size_t)num_experts_}, DataType::I32, device); +// Tensor expert_offsets = Tensor::zeros({(size_t)num_experts_ + 1}, DataType::I32, device); +// Tensor sorted_input = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_output = Tensor::empty({expanded_size, (size_t)hidden_dim_}, DataType::F32, device); +// Tensor sorted_row_map = Tensor::empty({expanded_size}, DataType::I32, device); +// Tensor sorted_weights = Tensor::empty({expanded_size}, DataType::F32, device); +// Tensor final_output = Tensor::zeros(hidden_states->shape(), DataType::F32, device); + +// launch_moe_sort( +// (int32_t*)top_k_index->data(), +// (int32_t*)expert_counts->data(), +// (int32_t*)expert_offsets->data(), +// (int)num_tokens, top_k, num_experts_, stream +// ); + +// launch_moe_permute( +// (float*)hidden_states->data(), +// (int32_t*)top_k_index->data(), +// (float*)top_k_values->data(), +// (int32_t*)expert_offsets->data(), +// (float*)sorted_input->data(), +// (int32_t*)sorted_row_map->data(), +// (float*)sorted_weights->data(), +// (int32_t*)expert_counts->data(), +// (int)num_tokens, top_k, hidden_dim_, num_experts_, stream +// ); + +// std::vector h_offsets(num_experts_ + 1); +// cudaMemcpy(h_offsets.data(), expert_offsets->data(), sizeof(int32_t) * (num_experts_ + 1), cudaMemcpyDeviceToHost); + +// for (int e = 0; e < num_experts_; ++e) { +// int start_idx = h_offsets[e]; +// int count = h_offsets[e+1] - start_idx; + +// // 【核心防御】防止 Error 700 / OOM +// // 如果 count 异常(可能是由于 Int64/32 读取错位导致的),直接跳过! 
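+// Note on the root cause this guard hints at: if the router hands back 64-bit indices and this
+// path reads them as int32_t, every other 4-byte word is a stray expert id, so counts/offsets are
+// corrupted long before this loop. A minimal sketch of the upstream fix, assuming the true index
+// dtype is known (hypothetical helper, not part of this codebase), is an explicit narrowing pass
+// run on the index tensor before launch_moe_sort:
+//
+// __global__ void cast_indices_i64_to_i32(const int64_t* src, int32_t* dst, int n) {
+//     int i = blockIdx.x * blockDim.x + threadIdx.x;
+//     if (i < n) dst[i] = static_cast<int32_t>(src[i]);   // expert ids are small; narrowing is lossless here
+// }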
+// if (count <= 0 || count > (int)expanded_size) { +// if (count > (int)expanded_size) { +// printf("WARNING: Expert %d skipped due to invalid count: %d\n", e, count); +// } +// printf("WARNING: Expert %d skipped due to invalid count: %d\n", e, count); +// continue; +// } + +// { +// Tensor expert_in = sorted_input->narrow({{0, (size_t)start_idx, (size_t)count}}); +// Tensor w_gate_up = gate_up_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)(2*intermediate_dim_), (size_t)hidden_dim_}); +// Tensor w_down = down_proj_->narrow({{0, (size_t)e, 1}})->view({(size_t)hidden_dim_, (size_t)intermediate_dim_}); + +// Tensor gate_up_out = infinicore::op::linear(expert_in, w_gate_up, std::nullopt); +// Tensor gate = gate_up_out->narrow({{1, 0, (size_t)intermediate_dim_}}); +// Tensor up = gate_up_out->narrow({{1, (size_t)intermediate_dim_, (size_t)intermediate_dim_}}); + +// Tensor ffn_inner = infinicore::op::mul(infinicore::op::silu(gate), up, this->handle_); +// Tensor expert_res = infinicore::op::linear(ffn_inner, w_down, std::nullopt); + +// cudaMemcpyAsync((float*)sorted_output->data() + start_idx * hidden_dim_, (float*)expert_res->data(), (size_t)count * hidden_dim_ * sizeof(float), cudaMemcpyDeviceToDevice, stream); +// } +// } + +// launch_moe_reduce((float*)sorted_output->data(), (int32_t*)sorted_row_map->data(), (float*)sorted_weights->data(), (float*)final_output->data(), (int)num_tokens, top_k, hidden_dim_, stream); + +// cudaStreamSynchronize(stream); +// return final_output; +// } + +// // GuMoeSparseMoeBlock (保持不变) +// GuMoeSparseMoeBlock::GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk, const DataType& dtype, const Device& device) { +// router_ = register_module("router", num_experts, hidden_dim, top_k, norm_topk, dtype, device); +// experts_ = register_module("experts", num_experts, hidden_dim, intermediate_dim, dtype, device); +// } +// Tensor GuMoeSparseMoeBlock::forward(const Tensor& hidden_states) { +// size_t total_tokens = hidden_states->numel() / (hidden_states->shape().back()); +// Tensor hidden_states_reshaped = hidden_states->view({total_tokens, hidden_states->shape().back()}); +// auto [routing_weights, selected_experts] = router_->forward(hidden_states_reshaped); +// Tensor final_hidden_states = experts_->forward(hidden_states_reshaped, selected_experts, routing_weights); +// return final_hidden_states->view(hidden_states->shape()); +// } + +// } // namespace infinicore::nn + +// #include +// #include +// #include +// #include + +// #define MAX_EXPERTS 256 + +// #define CUDA_CHECK(call) \ +// do { \ +// cudaError_t error = call; \ +// if (error != cudaSuccess) { \ +// fprintf(stderr, "CUDA Error at line %d: %s\n", __LINE__, cudaGetErrorString(error)); \ +// exit(1); \ +// } \ +// } while(0) + +// __global__ void count_kernel_sota( +// const int32_t* __restrict__ topk_ids, +// int32_t* __restrict__ expert_counts, +// int total_tasks, +// int num_experts +// ) { +// extern __shared__ int32_t smem_counts[]; + +// int tid = threadIdx.x; +// int bid = blockIdx.x; +// int gid = bid * blockDim.x + tid; + +// for (int i = tid; i < num_experts; i += blockDim.x) { +// smem_counts[i] = 0; +// } +// __syncthreads(); + +// if (gid < total_tasks) { +// int expert_id = topk_ids[gid]; + +// unsigned int active_mask = __activemask(); +// unsigned int mask = __match_any_sync(active_mask, expert_id); + +// int leader = __ffs(mask) - 1; // Find First Set +// int lane_id = tid % 32; + +// if (lane_id == leader) { + +// int agg_count = 
__popc(mask); + +// atomicAdd(&smem_counts[expert_id], agg_count); +// } +// } + +// __syncthreads(); + +// for (int i = tid; i < num_experts; i += blockDim.x) { +// int count = smem_counts[i]; +// if (count > 0) { +// atomicAdd(&expert_counts[i], count); +// } +// } +// } + +// void launch_moe_sort( +// const int32_t* topk_ids, +// int32_t* expert_counts, +// int32_t* expert_offsets, // 长度建议是 num_experts + 1 +// int num_tokens, +// int top_k, +// int num_experts, +// cudaStream_t stream +// ) { +// int total_tasks = num_tokens * top_k; +// int block_size = 256; +// int grid_size = (total_tasks + block_size - 1) / block_size; + +// // ------------------------------------------------------- +// CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, num_experts * sizeof(int32_t), stream)); + +// count_kernel_sota<<>>( +// topk_ids, expert_counts, total_tasks, num_experts +// ); + +// void* d_temp_storage = NULL; +// size_t temp_storage_bytes = 0; + +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, // 输入: counts +// expert_offsets, // 输出: offsets +// num_experts + 1,// 长度: 多算一位作为总和 +// stream); + +// CUDA_CHECK(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream)); + +// // 执行 +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, +// expert_offsets, +// num_experts + 1, +// stream); + +// CUDA_CHECK(cudaFreeAsync(d_temp_storage, stream)); +// } + +// __global__ void permute_kernel( +// const float* __restrict__ input, // [N, H] 源数据 +// const int32_t* __restrict__ topk_ids, // [N, K] 路由 +// const float* __restrict__ topk_weights, +// const int32_t* __restrict__ expert_offsets,// [E] 起始位置 +// int32_t* __restrict__ running_counters, // [E] 临时计数器 (原子加专用) +// float* __restrict__ sorted_input, // [N*K, H] 目标数据 +// int32_t* __restrict__ sorted_row_map, // [N*K] 来源记录 +// float* __restrict__ sorted_weights, +// int num_tokens, +// int top_k, +// int hidden_dim +// ) { +// // 任务总数 = Token数 * TopK (因为可能有复制) +// int total_tasks = num_tokens * top_k; +// int tid = blockIdx.x * blockDim.x + threadIdx.x; + +// if (tid >= total_tasks) return; + +// int token_idx = tid / top_k; +// // int k_idx = tid % top_k; // 如果 weights 是 [N, K] 布局,需要用这个 +// int expert_id = topk_ids[tid]; + +// // 获取该专家的起始地址 +// int base_offset = expert_offsets[expert_id]; +// // 原子获取我是该专家的第几个客人 +// int my_rank = atomicAdd(&running_counters[expert_id], 1); +// // 计算最终写入的行号 +// int target_row = base_offset + my_rank; + +// // 记下:第 target_row 行数据,其实是原来的 token_idx +// sorted_row_map[target_row] = token_idx; +// sorted_weights[target_row] = topk_weights[tid]; + +// // 从 input[token_idx] 搬到 sorted_input[target_row] +// const float* src_ptr = input + token_idx * hidden_dim; +// float* dst_ptr = sorted_input + target_row * hidden_dim; + +// // 尝试使用 float4 (128-bit) 进行搬运,减少指令数 +// int vec_size = hidden_dim / 4; +// int remainder = hidden_dim % 4; + +// // 强转指针进行向量化读取 +// const float4* src_vec = (const float4*)src_ptr; +// float4* dst_vec = (float4*)dst_ptr; + +// for (int i = 0; i < vec_size; ++i) { +// dst_vec[i] = src_vec[i]; +// } +// // 处理剩下的尾巴 (如果有的话) +// for (int i = 0; i < remainder; ++i) { +// int idx = vec_size * 4 + i; +// dst_ptr[idx] = src_ptr[idx]; +// } +// } + +// void launch_moe_permute( +// const float* input, +// const int32_t* topk_ids, +// const float* topk_weights, +// const int32_t* expert_offsets, +// float* sorted_input, +// int32_t* sorted_row_map, +// float* sorted_weights, +// int32_t* expert_counts, // <--- 复用这个数组作为临时计数器 +// int num_tokens, +// int 
top_k, +// int hidden_dim, +// int num_experts, +// cudaStream_t stream +// ) { +// int total_tasks = num_tokens * top_k; +// int block_size = 256; +// int grid_size = (total_tasks + block_size - 1) / block_size; + +// // 1. 【关键】把计数器重置为 0 +// // 这样每个专家才能从第 0 个开始数 +// CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, num_experts * sizeof(int32_t), stream)); + +// // 2. 启动 Kernel +// permute_kernel<<>>( +// input, +// topk_ids, +// topk_weights, +// expert_offsets, +// expert_counts, // 这里传进去当作 running_counters 用 +// sorted_input, +// sorted_row_map, +// sorted_weights, +// num_tokens, +// top_k, +// hidden_dim +// ); +// } + +// #include +// #include +// #include +// #include + +// #define MAX_EXPERTS 256 + +// // 增强版 Check 宏 +// #define CUDA_CHECK(call) \ +// do { \ +// cudaError_t error = call; \ +// if (error != cudaSuccess) { \ +// fprintf(stderr, "[KERNEL ERROR] %s failed at line %d: %s\n", #call, __LINE__, cudaGetErrorString(error)); \ +// exit(1); \ +// } \ +// } while(0) + +// // ============================================================= +// // 1. Count Kernel (统计每个专家的 token 数) +// // ============================================================= +// __global__ void count_kernel_sota( +// const int32_t* __restrict__ topk_ids, +// int32_t* __restrict__ expert_counts, +// int total_tasks, +// int num_experts +// ) { +// extern __shared__ int32_t smem_counts[]; + +// int tid = threadIdx.x; +// int bid = blockIdx.x; +// int gid = bid * blockDim.x + tid; + +// // 初始化共享内存 +// for (int i = tid; i < num_experts; i += blockDim.x) { +// smem_counts[i] = 0; +// } +// __syncthreads(); + +// // 统计 +// if (gid < total_tasks) { +// int expert_id = topk_ids[gid]; +// // 简单的边界检查 +// if (expert_id >= 0 && expert_id < num_experts) { +// unsigned int active_mask = __activemask(); +// unsigned int mask = __match_any_sync(active_mask, expert_id); +// int leader = __ffs(mask) - 1; +// int lane_id = tid % 32; +// if (lane_id == leader) { +// int agg_count = __popc(mask); +// atomicAdd(&smem_counts[expert_id], agg_count); +// } +// } +// } +// __syncthreads(); + +// // 写回全局内存 +// for (int i = tid; i < num_experts; i += blockDim.x) { +// int count = smem_counts[i]; +// if (count > 0) { +// atomicAdd(&expert_counts[i], count); +// } +// } +// } + +// void launch_moe_sort( +// const int32_t* topk_ids, +// int32_t* expert_counts, +// int32_t* expert_offsets, +// int num_tokens, +// int top_k, +// int num_experts, +// cudaStream_t stream +// ) { +// int total_tasks = num_tokens * top_k; +// int block_size = 256; +// int grid_size = (total_tasks + block_size - 1) / block_size; + +// // 清零 Counts +// CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, num_experts * sizeof(int32_t), stream)); + +// // 运行统计 +// count_kernel_sota<<>>( +// topk_ids, expert_counts, total_tasks, num_experts +// ); + +// // CUB Scan (前缀和) +// void* d_temp_storage = NULL; +// size_t temp_storage_bytes = 0; + +// // 查询所需显存 (注意 num_experts + 1 以计算总和) +// // 这里的 expert_counts 对应 gumoe.cpp 里申请的 (num_experts + 1) 大小,安全。 +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, expert_offsets, +// num_experts + 1, stream); + +// // ==================================================== +// // 【关键修改】使用同步 cudaMalloc +// // 必须替换掉原来的 cudaMallocAsync,否则在你的环境里会分配失败导致 Core Dump +// // ==================================================== +// CUDA_CHECK(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + +// // 执行 Scan +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, expert_offsets, +// 
num_experts + 1, stream); + +// // 同步释放 +// CUDA_CHECK(cudaFree(d_temp_storage)); +// } + +// // ============================================================= +// // 2. Permute Kernel (重排数据) +// // ============================================================= +// __global__ void permute_kernel( +// const float* __restrict__ input, +// const int32_t* __restrict__ topk_ids, +// const float* __restrict__ topk_weights, +// const int32_t* __restrict__ expert_offsets, +// int32_t* __restrict__ running_counters, +// float* __restrict__ sorted_input, +// int32_t* __restrict__ sorted_row_map, +// float* __restrict__ sorted_weights, +// int num_tokens, +// int top_k, +// int hidden_dim +// ) { +// int total_tasks = num_tokens * top_k; +// int tid = blockIdx.x * blockDim.x + threadIdx.x; + +// if (tid >= total_tasks) return; + +// int token_idx = tid / top_k; +// int expert_id = topk_ids[tid]; + +// // 原子获取写入位置 +// int base_offset = expert_offsets[expert_id]; +// int my_rank = atomicAdd(&running_counters[expert_id], 1); +// int target_row = base_offset + my_rank; + +// // 记录映射关系 +// if (sorted_row_map) sorted_row_map[target_row] = token_idx; +// if (sorted_weights) sorted_weights[target_row] = topk_weights[tid]; + +// // 搬运 Hidden States +// const float* src_ptr = input + token_idx * hidden_dim; +// float* dst_ptr = sorted_input + target_row * hidden_dim; + +// // 简单的 float4 优化 +// int vec_size = hidden_dim / 4; +// int remainder = hidden_dim % 4; +// const float4* src_vec = (const float4*)src_ptr; +// float4* dst_vec = (float4*)dst_ptr; + +// for (int i = 0; i < vec_size; ++i) { +// dst_vec[i] = src_vec[i]; +// } +// for (int i = 0; i < remainder; ++i) { +// int idx = vec_size * 4 + i; +// dst_ptr[idx] = src_ptr[idx]; +// } +// } + +// void launch_moe_permute( +// const float* input, +// const int32_t* topk_ids, +// const float* topk_weights, +// const int32_t* expert_offsets, +// float* sorted_input, +// int32_t* sorted_row_map, +// float* sorted_weights, +// int32_t* expert_counts, +// int num_tokens, +// int top_k, +// int hidden_dim, +// int num_experts, +// cudaStream_t stream +// ) { +// int total_tasks = num_tokens * top_k; +// int block_size = 256; +// int grid_size = (total_tasks + block_size - 1) / block_size; + +// // 复用 expert_counts 作为计数器,必须清零 +// CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, (num_experts + 1)* sizeof(int32_t), stream)); + +// permute_kernel<<>>( +// input, topk_ids, topk_weights, expert_offsets, expert_counts, +// sorted_input, sorted_row_map, sorted_weights, +// num_tokens, top_k, hidden_dim +// ); +// } + +// #include +// #include +// #include +// #include +// #include +// #define MAX_EXPERTS 256 + +// // 错误检查宏 +// #define CUDA_CHECK(call) \ +// do { \ +// cudaError_t error = call; \ +// if (error != cudaSuccess) { \ +// fprintf(stderr, "[KERNEL ERROR] %s failed at line %d: %s\n", #call, __LINE__, cudaGetErrorString(error)); \ +// exit(1); \ +// } \ +// } while(0) + +// // ========================================================== +// // 【新武器】GPU 数据探针 +// // ========================================================== +// __global__ void debug_inspector(int32_t* counts, int32_t* offsets, int num_experts) { +// if (threadIdx.x == 0 && blockIdx.x == 0) { +// printf("\n[GPU INSPECTOR] --- Start Analysis ---\n"); + +// // 1. 检查 Counts (前10个) +// printf("[GPU] Counts (First 10): "); +// bool counts_all_zero = true; +// for(int i=0; i 1000000000 || offsets[0] < 0) { +// printf("[GPU CRITICAL] Offsets[0] is garbage! 
CUB Scan failed.\n"); +// } +// if (counts_all_zero && offsets[num_experts] == 0) { +// printf("[GPU WARNING] Counts are all zero. Input indices might be wrong.\n"); +// } +// printf("[GPU INSPECTOR] --- End Analysis ---\n\n"); +// } +// } + +// // ---------------------------------------------------------- +// // Count Kernel (保持不变) +// // ---------------------------------------------------------- +// __global__ void count_kernel_sota( +// const int32_t* __restrict__ topk_ids, +// int32_t* __restrict__ expert_counts, +// int total_tasks, +// int num_experts +// ) { +// extern __shared__ int32_t smem_counts[]; +// int tid = threadIdx.x; +// int gid = blockIdx.x * blockDim.x + tid; +// if (gid == 0) { +// printf("[GPU ALIVE] Kernel started. total_tasks=%d, num_experts=%d\n", total_tasks, num_experts); +// printf("[GPU DATA] First topk_id = %d\n", topk_ids[0]); +// } +// for (int i = tid; i < num_experts; i += blockDim.x) smem_counts[i] = 0; +// __syncthreads(); + +// if (gid < total_tasks) { +// int expert_id = topk_ids[gid]; +// if (expert_id >= 0 && expert_id < num_experts) { +// unsigned int mask = __match_any_sync(__activemask(), expert_id); +// if ((tid % 32) == (__ffs(mask) - 1)) { +// atomicAdd(&smem_counts[expert_id], __popc(mask)); +// } +// } +// } +// __syncthreads(); +// for (int i = tid; i < num_experts; i += blockDim.x) { +// if (smem_counts[i] > 0) atomicAdd(&expert_counts[i], smem_counts[i]); +// // printf("这是count_kernel_sota的数字%d\n", smem_counts[i]); +// } +// } + +// // ---------------------------------------------------------- +// // Sort Launch (植入了探针) +// // ---------------------------------------------------------- +// // void launch_moe_sort( +// // const int32_t* topk_ids, +// // int32_t* expert_counts, +// // int32_t* expert_offsets, +// // int num_tokens, +// // int top_k, +// // int num_experts, +// // cudaStream_t stream +// // ) { +// // int total_tasks = num_tokens * top_k; +// // int block_size = 256; +// // int grid_size = (total_tasks + block_size - 1) / block_size; +// // printf("6\n"); +// // // 清零 (注意:这里用同步 memset 以排除异步干扰) +// // CUDA_CHECK(cudaMemset(expert_counts, 0, (num_experts + 1) * sizeof(int32_t))); + +// // count_kernel_sota<<>>( +// // topk_ids, expert_counts, total_tasks, num_experts +// // ); +// // printf("7\n"); +// // // CUB Scan +// // void* d_temp_storage = NULL; +// // size_t temp_storage_bytes = 0; + +// // cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// // expert_counts, expert_offsets, +// // num_experts + 1, stream); +// // printf("8\n"); +// // // 【强制同步分配】确保 Scan 内存绝对可用 +// // CUDA_CHECK(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + +// // cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// // expert_counts, expert_offsets, +// // num_experts + 1, stream); +// // printf("9\n"); +// // CUDA_CHECK(cudaFree(d_temp_storage)); +// // } +// void launch_moe_sort( +// const int32_t* topk_ids, +// int32_t* expert_counts, +// int32_t* expert_offsets, +// int num_tokens, +// int top_k, +// int num_experts, +// cudaStream_t stream +// ) { +// int total_tasks = num_tokens * top_k; +// int block_size = 256; +// int grid_size = (total_tasks + block_size - 1) / block_size; + +// printf("6 - Preparing to launch count_kernel\n"); + +// // 1. 清零 +// CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, (num_experts + 1) * sizeof(int32_t), stream)); + +// // 2. 计算共享内存大小 (关键!) 
+// size_t smem_size = (num_experts + 1) * sizeof(int32_t); + +// // [DEBUG] 打印启动参数,看看是不是有 0 +// printf(">>> Launch Params: Grid=%d, Block=%d, SharedMem=%zu bytes, Experts=%d\n", +// grid_size, block_size, smem_size, num_experts); + +// // 3. 启动 Kernel +// count_kernel_sota<<>>( +// topk_ids, expert_counts, total_tasks, num_experts +// ); + +// // ========================================================= +// // 【捕获启动失败】这是你没看到 printf 的真正原因 +// // ========================================================= +// cudaError_t launch_err = cudaGetLastError(); +// if (launch_err != cudaSuccess) { +// printf("❌ [FATAL] Kernel Launch Failed! Code=%d, Msg=%s\n", +// launch_err, cudaGetErrorString(launch_err)); +// // 这里不要 exit,打印出来让我们看到原因 +// } else { +// printf("✅ Kernel Launch Requested Successfully.\n"); +// } + +// // 4. CUB Scan (保持你现在的代码) +// void* d_temp_storage = NULL; +// size_t temp_storage_bytes = 0; + +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, expert_offsets, +// num_experts + 1, stream); + +// CUDA_CHECK(cudaMalloc(&d_temp_storage, temp_storage_bytes)); + +// cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, +// expert_counts, expert_offsets, +// num_experts + 1, stream); + +// CUDA_CHECK(cudaFree(d_temp_storage)); +// } + +// // ---------------------------------------------------------- +// // Permute Kernel (保持不变) +// // ---------------------------------------------------------- +// __global__ void permute_kernel( +// const float* __restrict__ input, +// const int32_t* __restrict__ topk_ids, +// const float* __restrict__ topk_weights, +// const int32_t* __restrict__ expert_offsets, +// int32_t* __restrict__ running_counters, +// float* __restrict__ sorted_input, +// int32_t* __restrict__ sorted_row_map, +// float* __restrict__ sorted_weights, +// int num_tokens, +// int top_k, +// int hidden_dim +// ) { +// int total_tasks = num_tokens * top_k; +// int tid = blockIdx.x * blockDim.x + threadIdx.x; + +// if (tid >= total_tasks) return; + +// int token_idx = tid / top_k; +// int expert_id = topk_ids[tid]; + +// int base_offset = expert_offsets[expert_id]; +// int my_rank = atomicAdd(&running_counters[expert_id], 1); +// int target_row = base_offset + my_rank; + +// if (sorted_row_map) sorted_row_map[target_row] = token_idx; +// if (sorted_weights) sorted_weights[target_row] = topk_weights[tid]; + +// const float* src_ptr = input + token_idx * hidden_dim; +// float* dst_ptr = sorted_input + target_row * hidden_dim; + +// for (int i = 0; i < hidden_dim; ++i) dst_ptr[i] = src_ptr[i]; +// } + +// void launch_moe_permute( +// const float* input, +// const int32_t* topk_ids, +// const float* topk_weights, +// const int32_t* expert_offsets, +// float* sorted_input, +// int32_t* sorted_row_map, +// float* sorted_weights, +// int32_t* expert_counts, +// int num_tokens, +// int top_k, +// int hidden_dim, +// int num_experts, +// cudaStream_t stream +// ) { +// int block_size = 256; +// int grid_size = (num_tokens * top_k + block_size - 1) / block_size; + +// // 清零 running_counters +// CUDA_CHECK(cudaMemset(expert_counts, 0, (num_experts + 1) * sizeof(int32_t))); + +// permute_kernel<<>>( +// input, topk_ids, topk_weights, expert_offsets, expert_counts, +// sorted_input, sorted_row_map, sorted_weights, +// num_tokens, top_k, hidden_dim +// ); +// } + + +// # import os +// # import torch +// # # 【强制黑魔法】告诉 PyTorch 我们要用新版 ABI +// # # 这行代码能救命,它防止 PyTorch 自己把 flag 改回 0 +// # torch._C._GLIBCXX_USE_CXX11_ABI = True + +// # from setuptools import 
setup +// # from torch.utils.cpp_extension import BuildExtension, CUDAExtension +// # import pybind11 + +// # INFINI_SRC_ROOT = "/data/users/shankgu/InfiniCore" +// # INFINI_LM_ROOT = "/data/users/shankgu/InfiniLM" +// # INFINI_LIB_DIR = "/data/users/shankgu/InfiniCore/build/linux/x86_64/release" + +// # # 你的库列表 (保持你之前的配置) +// # libs = [ +// # # 如果你的 gumoe.cpp 继承了 Module,你需要链接 utils 库 +// # os.path.join(INFINI_LIB_DIR, 'libinfini-utils.a'), +// # os.path.join(INFINI_LIB_DIR, 'libinfiniop-nvidia.a'), +// # os.path.join(INFINI_LIB_DIR, 'libinfiniccl-nvidia.a'), +// # os.path.join(INFINI_LIB_DIR, 'libinfinirt-nvidia.a') +// # ] + +// # setup( +// # name='gu_moe_ops', +// # version='0.1.0', +// # ext_modules=[ +// # CUDAExtension( +// # name='gu_moe_ops', +// # sources=[ +// # 'pybind_gumoe.cc', +// # 'src/gumoe.cpp', +// # 'src/gu_mul.cc', +// # 'src/gu_topk_softmax.cc', +// # 'src/nvidia_kernels/gu_reduce.cu', +// # 'src/nvidia_kernels/gu_sort.cu', +// # ], +// # include_dirs=[ +// # pybind11.get_include(), +// # os.path.join(INFINI_SRC_ROOT, 'include'), +// # os.path.join(INFINI_LM_ROOT, 'src'), +// # 'src' +// # ], +// # extra_objects=libs, + +// # extra_compile_args={ +// # # 【唯一关键点】必须设为 1,解决 ...ESs 报错 +// # 'cxx': ['-O3', '-std=c++17', '-D_GLIBCXX_USE_CXX11_ABI=1'], +// # 'nvcc': ['-O3'] +// # } +// # ) +// # ], +// # cmdclass={ +// # 'build_ext': BuildExtension +// # } +// # ) + +// #include +// #include +// #include +// #include +// #include +// #include + +// #include "src/gu_moe.h" +// #include "infinicore/tensor.hpp" +// #include "infinicore/device.hpp" + +// namespace py = pybind11; +// using namespace infinicore; + +// // int_to_dtype 略 (保持不变)... +// infinicore::DataType int_to_dtype(int id) { +// switch (id) { +// case 0: return infinicore::DataType::F32; +// case 1: return infinicore::DataType::BF16; +// case 2: return infinicore::DataType::I32; +// case 3: return infinicore::DataType::F16; +// default: throw std::runtime_error("Unknown dtype id: " + std::to_string(id)); +// } +// } + +// class PyGuMoeWrapper { +// public: +// std::shared_ptr block; + +// PyGuMoeWrapper(int num_experts, int hidden_dim, int intermediate_dim, +// int dtype_id, int device_id) { +// Device device(Device::Type::NVIDIA, device_id); +// block = std::make_shared( +// num_experts, hidden_dim, intermediate_dim, 2, true, +// int_to_dtype(dtype_id), device +// ); +// } + +// infinicore::nn::Parameter object_to_tensor(py::object tensor_obj) { +// uint64_t ptr_val = tensor_obj.attr("ptr").cast(); +// void* raw_ptr = reinterpret_cast(ptr_val); +// std::vector shape_vec = tensor_obj.attr("shape").cast>(); +// infinicore::Shape shape; +// for(auto s : shape_vec) shape.push_back(s); +// int dtype_id = tensor_obj.attr("dtype_id").cast(); +// int dev_id = tensor_obj.attr("device_id").cast(); +// infinicore::Device dev(infinicore::Device::Type::NVIDIA, dev_id); +// return infinicore::Tensor::from_blob(raw_ptr, shape, int_to_dtype(dtype_id), dev); +// } + +// // ✅✅✅ 【关键】这里只有 2 个参数! 
+// void forward(py::object input_obj, py::object output_obj) { +// auto input = object_to_tensor(input_obj); + +// // 调用 C++ Block (单参数 forward) +// auto internal_result = block->forward(input); + +// auto output_buffer = object_to_tensor(output_obj); +// size_t bytes = internal_result->numel() * 4; +// cudaMemcpy(output_buffer->data(), internal_result->data(), bytes, cudaMemcpyDeviceToDevice); +// } + +// // ✅ set_weights 接收 3 个参数:GateUp, Down, RouterWeight +// void set_weights(py::object gate_up_obj, py::object down_obj, py::object router_w_obj) { +// auto gate_up = object_to_tensor(gate_up_obj); +// auto down = object_to_tensor(down_obj); +// auto router_w = object_to_tensor(router_w_obj); + +// block->set_weights(gate_up, down, router_w); +// std::cout << "[C++] All weights set (Experts + Router)." << std::endl; +// } +// }; + +// PYBIND11_MODULE(gu_moe_ops, m) { +// py::class_(m, "GuMoeBlock") +// .def(py::init()) +// .def("forward", &PyGuMoeWrapper::forward) +// .def("set_weights", &PyGuMoeWrapper::set_weights); +// } \ No newline at end of file diff --git a/moe_gu_ops/src/nvidia_kernels/cuda_utils.h b/moe_gu_ops/src/nvidia_kernels/cuda_utils.h new file mode 100644 index 000000000..cce306eee --- /dev/null +++ b/moe_gu_ops/src/nvidia_kernels/cuda_utils.h @@ -0,0 +1,38 @@ +#ifndef CUDA_UTILS_H +#define CUDA_UTILS_H + +#include +#include +#include +#include + +// ========================================================= +// 标准 CUDA 错误检查宏 +// ========================================================= +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA Error at %s:%d code=%d(%s)\n", \ + __FILE__, __LINE__, err, cudaGetErrorString(err)); \ + exit(EXIT_FAILURE); \ + } \ + } while (0) + +// ========================================================= +// 辅助宏:计算 Grid 大小 +// ========================================================= +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) + +// ========================================================= +// 辅助函数:Warp 级归约 (Reduce) +// 很多 Reduce Kernel 会用到这个 +// ========================================================= +template +__inline__ __device__ T warpReduceSum(T val) { + for (int offset = warpSize / 2; offset > 0; offset /= 2) + val += __shfl_down_sync(0xffffffff, val, offset); + return val; +} + +#endif // CUDA_UTILS_H \ No newline at end of file diff --git a/moe_gu_ops/src/nvidia_kernels/cutlass b/moe_gu_ops/src/nvidia_kernels/cutlass new file mode 160000 index 000000000..712759206 --- /dev/null +++ b/moe_gu_ops/src/nvidia_kernels/cutlass @@ -0,0 +1 @@ +Subproject commit 7127592069c2fe01b041e174ba4345ef9b279671 diff --git a/moe_gu_ops/src/nvidia_kernels/gu_gemm_grouped.cu b/moe_gu_ops/src/nvidia_kernels/gu_gemm_grouped.cu new file mode 100644 index 000000000..f6766da40 --- /dev/null +++ b/moe_gu_ops/src/nvidia_kernels/gu_gemm_grouped.cu @@ -0,0 +1,256 @@ +#include "cuda_utils.h" +#include "nvidia_kernels_moe.h" +#include + +// 引入 CUTLASS +#include "cutlass/cutlass.h" +#include "cutlass/include/cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/include/cutlass/gemm/kernel/default_gemm_grouped.h" + +// ======================================================================== +// 1. 定义 CUTLASS GEMM 类型 +// ======================================================================== + +// 设定精度 +using ElementInputA = float; +using ElementInputB = float; +using ElementOutput = float; +using ElementAccumulator = float; + +// 设定布局 (关键!) 
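+// Why these layouts: each expert's weight is stored row-major as [N, K] (gate_up: [2*inter, hidden],
+// down: [hidden, inter]). Declaring operand B as ColumnMajor [K, N] over the same buffer is just a
+// reinterpretation of that memory, so the grouped GEMM computes X * W^T without an explicit transpose;
+// the launcher below passes ldb equal to K (hidden_dim for GEMM 1, inter_dim for GEMM 2) for this reason.
+// Illustration only: element (k, n) of the ColumnMajor view lives at buffer[n * ldb + k], which is
+// exactly row n, column k of the RowMajor [N, K] weight.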
+// A (Input): RowMajor [M, K] +using LayoutA = cutlass::layout::RowMajor; +// B (Weight): ColumnMajor [K, N] -> 物理上对应 RowMajor [N, K] +using LayoutB = cutlass::layout::ColumnMajor; +// C (Output): RowMajor [M, N] +using LayoutC = cutlass::layout::RowMajor; + +// 定义 Grouped GEMM 算子 +// 架构: Sm80 (Ampere A100/3090), 如果是 V100 用 Sm70 +using GemmGrouped = cutlass::gemm::device::GemmGrouped< + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<8, 8, 4>, + ElementInputA, LayoutA, + ElementInputB, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::gemm::GemmGroupedIteratorAlgorithm::kOffsetBased, + cutlass::arch::OpClassSimt, // FP32 使用 SIMT, 如果是 BF16/FP16 使用 OpClassTensorOp + cutlass::arch::Sm80 +>; + +// ======================================================================== +// 2. 参数准备 Kernel (Meta-Kernel) +// ======================================================================== +// 这个 Kernel 负责在 GPU 上生成 CUTLASS 需要的参数结构体 +__global__ void prepare_gemm_args( + const int32_t* __restrict__ offsets, // [Experts + 1] + const float* __restrict__ input_base, // 连续的 sorted_input + const float* __restrict__ weight_base, // 连续的权重 [E, N, K] + float* __restrict__ output_base, // 连续的 output + cutlass::gemm::GemmCoord* problem_sizes, // 输出: [E] 尺寸 + const float** ptr_A, // 输出: [E] 指针 + const float** ptr_B, // 输出: [E] 指针 + float** ptr_C, // 输出: [E] 指针 + float** ptr_D, // 输出: [E] 指针 + int num_experts, + int n, int k, // N, K 是固定的 + int lda, int ldb, int ldc // Strides +) { + int idx = threadIdx.x + blockIdx.x * blockDim.x; + if (idx >= num_experts) return; + + // 1. 算出当前专家的 M (Token 数) + int start_row = offsets[idx]; + int end_row = offsets[idx + 1]; + int m = end_row - start_row; + + // 2. 填写尺寸: M, N, K + problem_sizes[idx] = cutlass::gemm::GemmCoord(m, n, k); + + if (m > 0) { + // 3. 计算指针位置 + // A: Input [start_row, 0] + ptr_A[idx] = input_base + start_row * lda; + + // B: Weight [idx, 0, 0] + // 物理上 Weight 是 [E, N, K],每个专家占 N*K + ptr_B[idx] = weight_base + idx * (long long)n * k; + + // C/D: Output [start_row, 0] + ptr_C[idx] = output_base + start_row * ldc; + ptr_D[idx] = ptr_C[idx]; + } else { + ptr_A[idx] = nullptr; + ptr_B[idx] = nullptr; + ptr_C[idx] = nullptr; + ptr_D[idx] = nullptr; + } +} + +// ======================================================================== +// 3. 辅助 Kernel: Activation (SiLU * Mul) +// ======================================================================== +__global__ void silu_and_mul_kernel( + float* __restrict__ gate_up_output, // [Total_M, 2 * Inter] + int total_elements, // Total_M * Inter + int inter_dim +) { + int tid = blockIdx.x * blockDim.x + threadIdx.x; + if (tid >= total_elements) return; + + int row = tid / inter_dim; + int col = tid % inter_dim; + + // 内存布局: [Gate | Up] (RowMajor) + // Gate 在前半部分,Up 在后半部分 + long long gate_idx = (long long)row * (2 * inter_dim) + col; + long long up_idx = gate_idx + inter_dim; + + float gate_val = gate_up_output[gate_idx]; + float up_val = gate_up_output[up_idx]; + + // SiLU Calculation: x / (1 + exp(-x)) + float silu_val = gate_val / (1.0f + expf(-gate_val)); + + // In-place update: 把结果写回 Gate 的位置 + gate_up_output[gate_idx] = silu_val * up_val; +} + +// ======================================================================== +// 4. 
Host Launcher +// ======================================================================== + +void launch_moe_gemm_ffn( + const float* sorted_input, + const int32_t* expert_offsets, + float* sorted_output, + const float* gate_up_proj_base, // [Experts, 2*Inter, Hidden] + const float* down_proj_base, // [Experts, Hidden, Inter] + int num_experts, + int hidden_dim, // K for GEMM1, N for GEMM2 + int inter_dim, // N/2 for GEMM1, K for GEMM2 + cudaStream_t stream +) { + // ------------------------------------------------------------- + // Phase 1: 准备 CUTLASS 参数的显存 + // ------------------------------------------------------------- + // Grouped GEMM 需要传入指针数组。 + // 计算 Workspace 大小 + size_t size_coord = num_experts * sizeof(cutlass::gemm::GemmCoord); + size_t size_ptr = num_experts * sizeof(void*); + // 我们需要两套参数:一套给 GateUp GEMM,一套给 Down GEMM + // 为了简单,我们复用同一块显存,算完第一个再算第二个。 + size_t workspace_bytes = size_coord + 4 * size_ptr; + + // 【注意】这里先 malloc,工业级应从外部传入 Workspace + char* d_args_buffer; + CUDA_CHECK(cudaMallocAsync(&d_args_buffer, workspace_bytes, stream)); + + // 指针切分 + cutlass::gemm::GemmCoord* d_problem_sizes = (cutlass::gemm::GemmCoord*)(d_args_buffer); + const float** d_ptr_A = (const float**)(d_args_buffer + size_coord); + const float** d_ptr_B = (const float**)(d_args_buffer + size_coord + size_ptr); + float** d_ptr_C = (float**)(d_args_buffer + size_coord + 2 * size_ptr); + float** d_ptr_D = (float**)(d_args_buffer + size_coord + 3 * size_ptr); + + // ------------------------------------------------------------- + // Phase 2: GEMM 1 (Input * GateUp^T -> Middle) + // ------------------------------------------------------------- + // Input: [M, Hidden] + // Weight: [2*Inter, Hidden] (物理上) -> 逻辑上看作 [Hidden, 2*Inter] ColumnMajor + // Output: [M, 2*Inter] + + // 我们需要一个 Middle Buffer 存 Gate+Up 的结果 + // 获取总 Token 数 (需要从 CPU 或者拷贝 offsets 的最后一个值,这里为了简单假设已知或同步获取) + int total_tokens; + CUDA_CHECK(cudaMemcpyAsync(&total_tokens, expert_offsets + num_experts, sizeof(int), cudaMemcpyDeviceToHost, stream)); + CUDA_CHECK(cudaStreamSynchronize(stream)); // 等一下拿到 total_tokens + + float* d_middle; + CUDA_CHECK(cudaMallocAsync(&d_middle, total_tokens * 2 * inter_dim * sizeof(float), stream)); + + // 2.1 填充参数 + prepare_gemm_args<<< (num_experts+255)/256, 256, 0, stream >>>( + expert_offsets, + sorted_input, + gate_up_proj_base, + d_middle, + d_problem_sizes, d_ptr_A, d_ptr_B, d_ptr_C, d_ptr_D, + num_experts, + 2 * inter_dim, // N: 输出维度 + hidden_dim, // K: 输入维度 + hidden_dim, // lda (Input Row Stride) + hidden_dim, // ldb (Weight Column Stride = 物理上的 Row Stride) <--- 魔法在这里 + 2 * inter_dim // ldc (Output Row Stride) + ); + + // 2.2 运行 CUTLASS + GemmGrouped gemm; + typename GemmGrouped::Arguments args_1; + args_1.problem_sizes = d_problem_sizes; + args_1.count = num_experts; + args_1.threadblock_count = 0; + args_1.alpha = 1.0f; args_1.beta = 0.0f; + args_1.ptr_A = d_ptr_A; args_1.ptr_B = d_ptr_B; args_1.ptr_C = d_ptr_C; args_1.ptr_D = d_ptr_D; + // LDA/LDB/LDC 在 kernel 中算好了指针,这里设为 0 或默认即可,因为是指针模式 + + // 初始化并运行 + size_t gemm_ws_size = gemm.get_workspace_size(args_1); + void* gemm_ws = nullptr; + if (gemm_ws_size > 0) CUDA_CHECK(cudaMallocAsync(&gemm_ws, gemm_ws_size, stream)); + + CUDA_CHECK((cudaError_t)gemm.initialize(args_1, gemm_ws)); + CUDA_CHECK((cudaError_t)gemm.run(stream)); + + // ------------------------------------------------------------- + // Phase 3: Activation (SiLU) + // ------------------------------------------------------------- + int total_act_elements = total_tokens * inter_dim; + 
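+// Note: total_act_elements covers only the gate half of d_middle. The kernel below reads gate at
+// [row, col] and up at [row, col + inter_dim], then overwrites the gate slot with SiLU(gate) * up
+// in place; the up columns are left stale. GEMM 2 therefore reads d_middle with lda = 2*inter_dim
+// so those stale columns are skipped (see the Phase 4 argument list). If total_tokens can ever be 0,
+// these launches would need a guard first, since a grid dimension of 0 is an invalid configuration.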
+
+    // -------------------------------------------------------------
+    // Phase 3: Activation (SiLU * Up)
+    // -------------------------------------------------------------
+    int total_act_elements = total_tokens * inter_dim;
+    silu_and_mul_kernel<<< (total_act_elements + 255) / 256, 256, 0, stream >>>(
+        d_middle, total_act_elements, inter_dim
+    );
+
+    // -------------------------------------------------------------
+    // Phase 4: GEMM 2 (Middle * Down^T -> Output)
+    // -------------------------------------------------------------
+    // Input (Middle): [M, Inter]  (the activated result lives in the Gate half)
+    // Weight: [Hidden, Inter] physically -> viewed as [Inter, Hidden] ColumnMajor
+    // Output: [M, Hidden]
+
+    // 4.1 Fill the arguments again (d_args_buffer is reused)
+    prepare_gemm_args<<< (num_experts + 255) / 256, 256, 0, stream >>>(
+        expert_offsets,
+        d_middle,          // input is now the middle buffer
+        down_proj_base,
+        sorted_output,     // final output
+        d_problem_sizes, d_ptr_A, d_ptr_B, d_ptr_C, d_ptr_D,
+        d_lda, d_ldb, d_ldc, d_ldd,
+        num_experts,
+        hidden_dim,        // N: output dim
+        inter_dim,         // K: input dim
+        2 * inter_dim,     // lda (input stride; the Up half is a gap, so the row stride is 2*inter)
+        inter_dim,         // ldb (weight stride)
+        hidden_dim         // ldc (output stride)
+    );
+
+    // 4.2 Run CUTLASS again.
+    // The pointer/shape/stride arrays were rewritten in place on the device,
+    // but the operator must be re-initialized so it picks up the new arguments.
+    typename GemmGrouped::Arguments args_2(
+        d_problem_sizes,
+        num_experts,
+        threadblock_count,
+        EpilogueOp::Params(1.0f, 0.0f),
+        d_ptr_A, d_ptr_B, d_ptr_C, d_ptr_D,
+        d_lda, d_ldb, d_ldc, d_ldd
+    );
+
+    status = gemm.initialize(args_2, gemm_ws, stream);
+    if (status != cutlass::Status::kSuccess) {
+        fprintf(stderr, "[MoE GEMM2] initialize failed: %s\n", cutlassGetStatusString(status));
+    }
+    status = gemm.run(stream);
+    if (status != cutlass::Status::kSuccess) {
+        fprintf(stderr, "[MoE GEMM2] run failed: %s\n", cutlassGetStatusString(status));
+    }
+
+    // -------------------------------------------------------------
+    // Cleanup
+    // -------------------------------------------------------------
+    CUDA_CHECK(cudaFreeAsync(d_args_buffer, stream));
+    CUDA_CHECK(cudaFreeAsync(d_middle, stream));
+    if (gemm_ws) CUDA_CHECK(cudaFreeAsync(gemm_ws, stream));
+}
\ No newline at end of file
diff --git a/moe_gu_ops/src/nvidia_kernels/gu_reduce.cu b/moe_gu_ops/src/nvidia_kernels/gu_reduce.cu
new file mode 100644
index 000000000..8b64973a7
--- /dev/null
+++ b/moe_gu_ops/src/nvidia_kernels/gu_reduce.cu
@@ -0,0 +1,87 @@
+#include "nvidia_kernels_moe.h"
+#include "cuda_utils.h"
+
+// ========================================================================
+// Kernel: Reduce (weighted scatter back) - float4 vectorized version
+// ========================================================================
+__global__ void reduce_kernel_opt(
+    const float* __restrict__ sorted_output,    // [Total_Tasks, H]
+    const int32_t* __restrict__ sorted_row_map, // [Total_Tasks]
+    const float* __restrict__ sorted_weights,   // [Total_Tasks]
+    float* __restrict__ final_output,           // [Num_Tokens, H]
+    int total_tasks,                            // N * K
+    int hidden_dim
+) {
+    // Strategy: each thread handles 4 elements (128 bits), so the grid is
+    // sized over hidden_dim / 4. The launcher only takes this path when
+    // hidden_dim is a multiple of 4; otherwise it falls back to the scalar
+    // kernel, so no remainder handling is needed here.
+    int vec_dim = hidden_dim / 4;
+    long long tid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid < (long long)total_tasks * vec_dim) {
+        int row_idx = (int)(tid / vec_dim);   // which task row
+        int vec_idx = (int)(tid % vec_dim);   // which float4 within the row
+        int col_idx = vec_idx * 4;            // actual column
+
+        // Look up which token this row belongs to and its routing weight
+        int original_token_idx = sorted_row_map[row_idx];
+        float weight = sorted_weights[row_idx];
+
+        // Vectorized 128-bit load
+        const float4* src_vec_ptr = (const float4*)sorted_output;
+        float4 val_vec = src_vec_ptr[tid];
+
+        // Destination base address
+        float* dst_base = final_output + (long long)original_token_idx * hidden_dim + col_idx;
+
+        // atomicAdd has no float4 overload, so the accumulation is split;
+        // the load traffic is still reduced, which is where the speedup comes from.
+        atomicAdd(dst_base + 0, val_vec.x * weight);
+        atomicAdd(dst_base + 1, val_vec.y * weight);
+        atomicAdd(dst_base + 2, val_vec.z * weight);
+        atomicAdd(dst_base + 3, val_vec.w * weight);
+    }
+}
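+
+// ========================================================================
+// Scalar fallback (sketch)
+// ========================================================================
+// The launcher below refers to a scalar path for hidden sizes that are not a
+// multiple of 4. A minimal version is sketched here, assuming one thread per
+// output element; it is functionally the same as reduce_kernel in gu_sort.cu
+// and can be replaced by that kernel if preferred.
+__global__ void reduce_kernel_scalar(
+    const float* __restrict__ sorted_output,    // [Total_Tasks, H]
+    const int32_t* __restrict__ sorted_row_map, // [Total_Tasks]
+    const float* __restrict__ sorted_weights,   // [Total_Tasks]
+    float* __restrict__ final_output,           // [Num_Tokens, H]
+    int total_tasks,
+    int hidden_dim
+) {
+    long long tid = (long long)blockIdx.x * blockDim.x + threadIdx.x;
+    long long total_elements = (long long)total_tasks * hidden_dim;
+    if (tid >= total_elements) return;
+
+    int row = (int)(tid / hidden_dim);
+    int col = (int)(tid % hidden_dim);
+
+    int original_token_idx = sorted_row_map[row];
+    float weight = sorted_weights[row];
+    float val = sorted_output[tid];
+
+    // Weighted accumulation back into the token's original row
+    atomicAdd(final_output + (long long)original_token_idx * hidden_dim + col, val * weight);
+}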
+
+// ========================================================================
+// Host Launcher
+// ========================================================================
+void launch_moe_reduce(
+    const float* sorted_output,
+    const int32_t* sorted_row_map,
+    const float* sorted_weights,
+    float* final_output,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    cudaStream_t stream
+) {
+    int total_tasks = num_tokens * top_k;
+
+    // Prefer the vectorized kernel when the hidden size allows it
+    if (hidden_dim % 4 == 0) {
+        int vec_dim = hidden_dim / 4;
+        long long total_threads = (long long)total_tasks * vec_dim;
+        int block_size = 256;
+        int grid_size = (int)((total_threads + block_size - 1) / block_size);
+
+        reduce_kernel_opt<<<grid_size, block_size, 0, stream>>>(
+            sorted_output, sorted_row_map, sorted_weights, final_output,
+            total_tasks, hidden_dim
+        );
+    } else {
+        long long total_elements = (long long)total_tasks * hidden_dim;
+        int block_size = 256;
+        int grid_size = (int)((total_elements + block_size - 1) / block_size);
+
+        // Slow path: one thread per element (see the scalar sketch above)
+        reduce_kernel_scalar<<<grid_size, block_size, 0, stream>>>(
+            sorted_output, sorted_row_map, sorted_weights, final_output,
+            total_tasks, hidden_dim
+        );
+    }
+}
\ No newline at end of file
diff --git a/moe_gu_ops/src/nvidia_kernels/gu_sort.cu b/moe_gu_ops/src/nvidia_kernels/gu_sort.cu
new file mode 100644
index 000000000..46e1936df
--- /dev/null
+++ b/moe_gu_ops/src/nvidia_kernels/gu_sort.cu
@@ -0,0 +1,236 @@
+#include <cuda_runtime.h>
+#include <cub/cub.cuh>
+#include <cstdio>
+#include <cstdlib>
+#include <cstdint>
+
+#define MAX_EXPERTS 256
+
+// Error-checking macro
+#define CUDA_CHECK(call) \
+do { \
+    cudaError_t error = call; \
+    if (error != cudaSuccess) { \
+        fprintf(stderr, "[KERNEL ERROR] %s failed at line %d: %s\n", #call, __LINE__, cudaGetErrorString(error)); \
+        exit(1); \
+    } \
+} while(0)
+
+// =============================================================
+// 1. Count Kernel (clean version, with out-of-range protection)
+// =============================================================
+__global__ void count_kernel_sota(
+    const int32_t* __restrict__ topk_ids,
+    int32_t* __restrict__ expert_counts,
+    int total_tasks,
+    int num_experts
+) {
+    extern __shared__ int32_t smem_counts[];
+
+    int tid = threadIdx.x;
+    int bid = blockIdx.x;
+    int gid = bid * blockDim.x + tid;
+
+    // Clear shared memory
+    for (int i = tid; i < num_experts; i += blockDim.x) {
+        smem_counts[i] = 0;
+    }
+    __syncthreads();
+
+    // Count (with bounds check)
+    if (gid < total_tasks) {
+        int expert_id = topk_ids[gid];
+        // Protection: silently drop illegal IDs instead of crashing
+        if (expert_id >= 0 && expert_id < num_experts) {
+            // Warp-aggregated atomics: lanes holding the same expert_id elect a
+            // leader, which adds the whole group's count with one atomicAdd.
+            unsigned int mask = __match_any_sync(__activemask(), expert_id);
+            int leader = __ffs(mask) - 1;
+            int lane_id = tid % 32;
+            if (lane_id == leader) {
+                int agg_count = __popc(mask);
+                atomicAdd(&smem_counts[expert_id], agg_count);
+            }
+        }
+    }
+    __syncthreads();
+
+    // Flush block-local counts to global memory
+    for (int i = tid; i < num_experts; i += blockDim.x) {
+        int count = smem_counts[i];
+        if (count > 0) {
+            atomicAdd(&expert_counts[i], count);
+        }
+    }
+}
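+
+// -------------------------------------------------------------
+// Illustrative only: a tiny host-side helper that dumps the per-expert
+// counts after count_kernel_sota has run. It is not used by the production
+// path and the function name is a placeholder; it synchronizes the stream,
+// so it is meant for bring-up/debugging only.
+// -------------------------------------------------------------
+inline void debug_dump_expert_counts(const int32_t* d_expert_counts,
+                                     int num_experts,
+                                     cudaStream_t stream) {
+    if (num_experts <= 0 || num_experts > MAX_EXPERTS) return;
+    int32_t h_counts[MAX_EXPERTS];
+    CUDA_CHECK(cudaStreamSynchronize(stream));
+    CUDA_CHECK(cudaMemcpy(h_counts, d_expert_counts,
+                          num_experts * sizeof(int32_t), cudaMemcpyDeviceToHost));
+    for (int e = 0; e < num_experts; ++e) {
+        printf("expert %d -> %d tokens\n", e, h_counts[e]);
+    }
+}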
+
+// =============================================================
+// 2. Sort Launch (clean version)
+// =============================================================
+void launch_moe_sort(
+    const int32_t* topk_ids,
+    int32_t* expert_counts,
+    int32_t* expert_offsets,
+    int num_tokens,
+    int top_k,
+    int num_experts,
+    cudaStream_t stream
+) {
+    int total_tasks = num_tokens * top_k;
+    int block_size = 256;
+    int grid_size = (total_tasks + block_size - 1) / block_size;
+
+    // 1. Zero the counts (must cover num_experts + 1 entries)
+    CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, (num_experts + 1) * sizeof(int32_t), stream));
+
+    // 2. Shared-memory size for the counting kernel (do not skip this step)
+    size_t smem_size = (num_experts + 1) * sizeof(int32_t);
+
+    // 3. Launch the counting kernel
+    count_kernel_sota<<<grid_size, block_size, smem_size, stream>>>(
+        topk_ids, expert_counts, total_tasks, num_experts
+    );
+
+    // 4. CUB scan (exclusive prefix sum)
+    void* d_temp_storage = NULL;
+    size_t temp_storage_bytes = 0;
+
+    // Query the required temporary storage
+    cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
+                                  expert_counts, expert_offsets,
+                                  num_experts + 1, stream);
+
+    // Allocate temporary storage (synchronous malloc, kept deliberately simple)
+    CUDA_CHECK(cudaMalloc(&d_temp_storage, temp_storage_bytes));
+
+    // Run the scan
+    cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes,
+                                  expert_counts, expert_offsets,
+                                  num_experts + 1, stream);
+
+    // Free the temporary storage
+    CUDA_CHECK(cudaFree(d_temp_storage));
+}
+
+// =============================================================
+// 3. Permute Kernel (clean version, with out-of-range protection)
+// =============================================================
+__global__ void permute_kernel(
+    const float* __restrict__ input,
+    const int32_t* __restrict__ topk_ids,
+    const float* __restrict__ topk_weights,
+    const int32_t* __restrict__ expert_offsets,
+    int32_t* __restrict__ running_counters,
+    float* __restrict__ sorted_input,
+    int32_t* __restrict__ sorted_row_map,
+    float* __restrict__ sorted_weights,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    int num_experts
+) {
+    int total_tasks = num_tokens * top_k;
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+    if (tid >= total_tasks) return;
+
+    int token_idx = tid / top_k;
+    int expert_id = topk_ids[tid];
+
+    // Protection: skip illegal IDs instead of reading expert_offsets out of range
+    if (expert_id < 0 || expert_id >= num_experts) return;
+
+    int base_offset = expert_offsets[expert_id];
+    int my_rank = atomicAdd(&running_counters[expert_id], 1);
+    int target_row = base_offset + my_rank;
+
+    if (sorted_row_map) sorted_row_map[target_row] = token_idx;
+    if (sorted_weights) sorted_weights[target_row] = topk_weights[tid];
+
+    const float* src_ptr = input + (long long)token_idx * hidden_dim;
+    float* dst_ptr = sorted_input + (long long)target_row * hidden_dim;
+
+    if ((hidden_dim & 3) == 0) {
+        // float4-vectorized row copy; safe because the row stride is a
+        // multiple of 4 floats, so every row base stays 16-byte aligned.
+        const float4* src_vec = (const float4*)src_ptr;
+        float4* dst_vec = (float4*)dst_ptr;
+        int vec_len = hidden_dim / 4;
+        for (int i = 0; i < vec_len; ++i) dst_vec[i] = src_vec[i];
+    } else {
+        // Scalar fallback (row bases may not be 16-byte aligned here)
+        for (int i = 0; i < hidden_dim; ++i) dst_ptr[i] = src_ptr[i];
+    }
+}
+
+void launch_moe_permute(
+    const float* input,
+    const int32_t* topk_ids,
+    const float* topk_weights,
+    const int32_t* expert_offsets,
+    float* sorted_input,
+    int32_t* sorted_row_map,
+    float* sorted_weights,
+    int32_t* expert_counts,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    int num_experts,
+    cudaStream_t stream
+) {
+    int block_size = 256;
+    int grid_size = (num_tokens * top_k + block_size - 1) / block_size;
+
+    // expert_counts is reused as running_counters, so it must be zeroed first
+    CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, (num_experts + 1) * sizeof(int32_t), stream));
+
+    permute_kernel<<<grid_size, block_size, 0, stream>>>(
+        input, topk_ids, topk_weights, expert_offsets, expert_counts,
+        sorted_input, sorted_row_map, sorted_weights,
+        num_tokens, top_k, hidden_dim, num_experts
+    );
+}
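+
+// -------------------------------------------------------------
+// Illustrative only: a CPU reference of what the ExclusiveSum above produces,
+// useful when writing tests. Given counts[0..E-1] it fills offsets[0..E] with
+// offsets[0] = 0 and offsets[E] = total task count, matching the layout the
+// GPU path writes into expert_offsets. The function name is a placeholder
+// and is not used elsewhere in this patch.
+// -------------------------------------------------------------
+inline void exclusive_scan_reference_cpu(const int32_t* counts,
+                                         int32_t* offsets,
+                                         int num_experts) {
+    offsets[0] = 0;
+    for (int e = 0; e < num_experts; ++e) {
+        offsets[e + 1] = offsets[e] + counts[e];
+    }
+}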
+
+// =============================================================
+// 4. Reduce Kernel (clean scalar reference)
+// =============================================================
+// Note: this duplicates the scalar path in gu_reduce.cu. The launcher below
+// is kept as a simple reference; the entry point declared in
+// nvidia_kernels_moe.h is the const-pointer version defined in gu_reduce.cu.
+__global__ void reduce_kernel(
+    const float* __restrict__ sorted_output,
+    const int32_t* __restrict__ sorted_row_map,
+    const float* __restrict__ sorted_weights,
+    float* __restrict__ final_output,
+    int total_tasks,
+    int hidden_dim
+) {
+    int tid = blockIdx.x * blockDim.x + threadIdx.x;
+    int total_elements = total_tasks * hidden_dim;
+
+    if (tid >= total_elements) return;
+
+    int row = tid / hidden_dim;
+    int col = tid % hidden_dim;
+
+    int original_token_idx = sorted_row_map[row];
+    float weight = sorted_weights[row];
+    float val = sorted_output[tid];
+
+    // Weighted accumulation back into the token's original row
+    float* target_ptr = final_output + original_token_idx * hidden_dim + col;
+    atomicAdd(target_ptr, val * weight);
+}
+
+void launch_moe_reduce(
+    float* sorted_output,
+    int32_t* sorted_row_map,
+    float* sorted_weights,
+    float* final_output,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    cudaStream_t stream
+) {
+    int total_tasks = num_tokens * top_k;
+    int total_elements = total_tasks * hidden_dim;
+    int block_size = 256;
+    int grid_size = (total_elements + block_size - 1) / block_size;
+
+    reduce_kernel<<<grid_size, block_size, 0, stream>>>(
+        sorted_output, sorted_row_map, sorted_weights, final_output,
+        total_tasks, hidden_dim
+    );
+}
\ No newline at end of file
diff --git a/moe_gu_ops/src/nvidia_kernels/nvidia_kernels_moe.h b/moe_gu_ops/src/nvidia_kernels/nvidia_kernels_moe.h
new file mode 100644
index 000000000..1462e281a
--- /dev/null
+++ b/moe_gu_ops/src/nvidia_kernels/nvidia_kernels_moe.h
@@ -0,0 +1,42 @@
+#pragma once
+#include <cuda_runtime.h>
+#include <cstdint>
+
+// Operator 1: sort and offset computation (Count + CUB Scan)
+void launch_moe_sort(
+    const int32_t* topk_ids,
+    int32_t* expert_counts,
+    int32_t* expert_offsets,
+    int num_tokens,
+    int top_k,
+    int num_experts,
+    cudaStream_t stream
+);
+
+// Operator 2: data movement (Permutation)
+void launch_moe_permute(
+    const float* input,
+    const int32_t* topk_ids,
+    const float* topk_weights,   // [Input] original routing weights
+    const int32_t* expert_offsets,
+    float* sorted_input,
+    int32_t* sorted_row_map,
+    float* sorted_weights,       // [Output] weights in sorted order
+    int32_t* expert_counts,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    int num_experts,
+    cudaStream_t stream
+);
+
+// Operator 3: weighted reduction back to the original token order
+void launch_moe_reduce(
+    const float* sorted_output,
+    const int32_t* sorted_row_map,
+    const float* sorted_weights,
+    float* final_output,
+    int num_tokens,
+    int top_k,
+    int hidden_dim,
+    cudaStream_t stream
+);
\ No newline at end of file
diff --git a/moe_gu_ops/src/nvidia_kernels/test/test_moe_sort b/moe_gu_ops/src/nvidia_kernels/test/test_moe_sort
new file mode 100755
index 000000000..a29eba786
Binary files /dev/null and b/moe_gu_ops/src/nvidia_kernels/test/test_moe_sort differ
diff --git a/moe_gu_ops/src/nvidia_kernels/test/test_sort.cu b/moe_gu_ops/src/nvidia_kernels/test/test_sort.cu
new file mode 100644
index 000000000..3965077ad
--- /dev/null
+++ b/moe_gu_ops/src/nvidia_kernels/test/test_sort.cu
@@ -0,0 +1,333 @@
+#include <cuda_runtime.h>
+#include <cub/cub.cuh>
+#include <cstdio>
+#include <cstdlib>
+#include <cstdint>
+#include <cmath>
+#include <vector>
+#include <iostream>
+
+#define MAX_EXPERTS 256
+
+// =================================================================================
+// 1. 
辅助宏与工具 +// ================================================================================= +#define CUDA_CHECK(call) \ + do { \ + cudaError_t err = call; \ + if (err != cudaSuccess) { \ + fprintf(stderr, "CUDA Error at %s:%d - %s\n", __FILE__, __LINE__, cudaGetErrorString(err)); \ + exit(1); \ + } \ + } while (0) + +// ================================================================================= +// 2. 你的 Kernel 实现 (直接复制粘贴你的代码) +// ================================================================================= + +__global__ void count_kernel_sota( + const int32_t* __restrict__ topk_ids, + int32_t* __restrict__ expert_counts, + int total_tasks, + int num_experts +) { + extern __shared__ int32_t smem_counts[]; + + int tid = threadIdx.x; + int bid = blockIdx.x; + int gid = bid * blockDim.x + tid; + + for (int i = tid; i < num_experts; i += blockDim.x) { + smem_counts[i] = 0; + } + __syncthreads(); + + if (gid < total_tasks) { + int expert_id = topk_ids[gid]; + // 简单的 Warp 聚合逻辑验证 + unsigned int active_mask = __activemask(); + unsigned int mask = __match_any_sync(active_mask, expert_id); + int leader = __ffs(mask) - 1; + int lane_id = tid % 32; + + if (lane_id == leader) { + int agg_count = __popc(mask); + atomicAdd(&smem_counts[expert_id], agg_count); + } + } + + __syncthreads(); + + for (int i = tid; i < num_experts; i += blockDim.x) { + int count = smem_counts[i]; + if (count > 0) { + atomicAdd(&expert_counts[i], count); + } + } +} + +void launch_moe_sort( + const int32_t* topk_ids, + int32_t* expert_counts, + int32_t* expert_offsets, + int num_tokens, + int top_k, + int num_experts, + cudaStream_t stream +) { + int total_tasks = num_tokens * top_k; + int block_size = 256; + int grid_size = (total_tasks + block_size - 1) / block_size; + + CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, num_experts * sizeof(int32_t), stream)); + + count_kernel_sota<<>>( + topk_ids, expert_counts, total_tasks, num_experts + ); + + void* d_temp_storage = NULL; + size_t temp_storage_bytes = 0; + + cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, expert_counts, expert_offsets, num_experts + 1, stream); + CUDA_CHECK(cudaMallocAsync(&d_temp_storage, temp_storage_bytes, stream)); + cub::DeviceScan::ExclusiveSum(d_temp_storage, temp_storage_bytes, expert_counts, expert_offsets, num_experts + 1, stream); + CUDA_CHECK(cudaFreeAsync(d_temp_storage, stream)); +} + +__global__ void permute_kernel( + const float* __restrict__ input, + const int32_t* __restrict__ topk_ids, + const int32_t* __restrict__ expert_offsets, + int32_t* __restrict__ running_counters, + float* __restrict__ sorted_input, + int32_t* __restrict__ sorted_row_map, + int num_tokens, + int top_k, + int hidden_dim +) { + int total_tasks = num_tokens * top_k; + int tid = blockIdx.x * blockDim.x + threadIdx.x; + + if (tid >= total_tasks) return; + + int token_idx = tid / top_k; + int expert_id = topk_ids[tid]; + + int base_offset = expert_offsets[expert_id]; + int my_rank = atomicAdd(&running_counters[expert_id], 1); + int target_row = base_offset + my_rank; + + sorted_row_map[target_row] = token_idx; + + const float* src_ptr = input + token_idx * hidden_dim; + float* dst_ptr = sorted_input + target_row * hidden_dim; + + int vec_size = hidden_dim / 4; + int remainder = hidden_dim % 4; + + const float4* src_vec = (const float4*)src_ptr; + float4* dst_vec = (float4*)dst_ptr; + + for (int i = 0; i < vec_size; ++i) { + dst_vec[i] = src_vec[i]; + } + for (int i = 0; i < remainder; ++i) { + int idx = vec_size * 4 + i; + dst_ptr[idx] = 
src_ptr[idx]; + } +} + +void launch_moe_permute( + const float* input, + const int32_t* topk_ids, + const int32_t* expert_offsets, + float* sorted_input, + int32_t* sorted_row_map, + int32_t* expert_counts, + int num_tokens, + int top_k, + int hidden_dim, + int num_experts, + cudaStream_t stream +) { + int total_tasks = num_tokens * top_k; + int block_size = 256; + int grid_size = (total_tasks + block_size - 1) / block_size; + + CUDA_CHECK(cudaMemsetAsync(expert_counts, 0, num_experts * sizeof(int32_t), stream)); + + permute_kernel<<>>( + input, topk_ids, expert_offsets, expert_counts, + sorted_input, sorted_row_map, num_tokens, top_k, hidden_dim + ); +} + +// ================================================================================= +// 3. CPU 基准验证逻辑 (Ground Truth) +// ================================================================================= +void verify_results( + const std::vector& h_input, + const std::vector& h_topk_ids, + const std::vector& gpu_offsets, + const std::vector& gpu_sorted_input, + const std::vector& gpu_row_map, + int num_tokens, int top_k, int hidden_dim, int num_experts +) { + int total_tasks = num_tokens * top_k; + + // 1. CPU Count + std::vector cpu_counts(num_experts, 0); + for (int i = 0; i < total_tasks; ++i) { + cpu_counts[h_topk_ids[i]]++; + } + + // 2. CPU Offset + std::vector cpu_offsets(num_experts + 1, 0); + for (int i = 0; i < num_experts; ++i) { + cpu_offsets[i + 1] = cpu_offsets[i] + cpu_counts[i]; + } + + // 验证 Offsets + bool offset_ok = true; + for (int i = 0; i <= num_experts; ++i) { + if (cpu_offsets[i] != gpu_offsets[i]) { + std::cout << "❌ Offset Mismatch at Expert " << i + << ": CPU=" << cpu_offsets[i] << ", GPU=" << gpu_offsets[i] << std::endl; + offset_ok = false; + } + } + if (offset_ok) std::cout << "✅ Offsets Verification Passed!" << std::endl; + + // 3. CPU Permute + std::vector cpu_sorted_input(total_tasks * hidden_dim, 0.0f); + std::vector cpu_row_map(total_tasks, 0); + std::vector running_counters(num_experts, 0); + + for (int t = 0; t < num_tokens; ++t) { + for (int k = 0; k < top_k; ++k) { + int task_idx = t * top_k + k; + int expert_id = h_topk_ids[task_idx]; + + int base = cpu_offsets[expert_id]; + int rank = running_counters[expert_id]++; + int target_row = base + rank; + + // 记录 Row Map + cpu_row_map[target_row] = t; + + // 搬运数据 + for (int h = 0; h < hidden_dim; ++h) { + cpu_sorted_input[target_row * hidden_dim + h] = h_input[t * hidden_dim + h]; + } + } + } + + // 验证 Row Map (注意:多线程下的 row map 顺序对于同一个 Expert 内部可能是不确定的, + // 但是在这个测试用例中,我们单线程生成数据,GPU 也是顺序 atomic,通常是一致的。 + // 如果不一致,我们要检查是否属于同一个 Expert。 + // 严格来说,只需验证 gpu_sorted_input 中的数据是否等于 input[gpu_row_map[i]] 且 expert id 匹配) + + // 我们采用宽松验证:验证 gpu_sorted_input 的值是否正确 + bool data_ok = true; + for (int i = 0; i < total_tasks * hidden_dim; ++i) { + float diff = std::abs(gpu_sorted_input[i] - cpu_sorted_input[i]); + if (diff > 1e-5) { + std::cout << "❌ Data Mismatch at index " << i + << ": CPU=" << cpu_sorted_input[i] << ", GPU=" << gpu_sorted_input[i] << std::endl; + data_ok = false; + if (i > 10) break; // 防止刷屏 + } + } + + if (data_ok) std::cout << "✅ Sorted Data Verification Passed!" << std::endl; + else std::cout << "❌ Data Verification Failed." << std::endl; +} + + +// ================================================================================= +// 4. 
Main +// ================================================================================= +int main() { + // 设置参数 + const int num_tokens = 16; // 少量 Token 用于调试 + const int hidden_dim = 8; // 必须是 4 的倍数以便测试 float4 + const int top_k = 2; + const int num_experts = 4; + const int total_tasks = num_tokens * top_k; + + std::cout << ">>> Running MoE Sort Test..." << std::endl; + std::cout << "Tokens: " << num_tokens << ", Hidden: " << hidden_dim + << ", TopK: " << top_k << ", Experts: " << num_experts << std::endl; + + // 1. Host 准备数据 + std::vector h_input(num_tokens * hidden_dim); + std::vector h_topk_ids(total_tasks); + + // 初始化 Input: Token i 的第 j 个数为 i + 0.01*j + for (int i = 0; i < num_tokens; ++i) { + for (int j = 0; j < hidden_dim; ++j) { + h_input[i * hidden_dim + j] = (float)i + 0.01f * j; + } + } + + // 初始化 Indices: 简单的循环模式 0, 1, 2, 3... + for (int i = 0; i < total_tasks; ++i) { + h_topk_ids[i] = i % num_experts; + } + + // 2. Device 分配内存 + float *d_input, *d_sorted_input; + int32_t *d_topk_ids, *d_expert_counts, *d_expert_offsets, *d_sorted_row_map; + + CUDA_CHECK(cudaMalloc(&d_input, h_input.size() * sizeof(float))); + CUDA_CHECK(cudaMalloc(&d_topk_ids, h_topk_ids.size() * sizeof(int32_t))); + CUDA_CHECK(cudaMalloc(&d_sorted_input, total_tasks * hidden_dim * sizeof(float))); + CUDA_CHECK(cudaMalloc(&d_expert_counts, num_experts * sizeof(int32_t))); + CUDA_CHECK(cudaMalloc(&d_expert_offsets, (num_experts + 1) * sizeof(int32_t))); + CUDA_CHECK(cudaMalloc(&d_sorted_row_map, total_tasks * sizeof(int32_t))); + + // 3. 拷贝数据到 GPU + CUDA_CHECK(cudaMemcpy(d_input, h_input.data(), h_input.size() * sizeof(float), cudaMemcpyHostToDevice)); + CUDA_CHECK(cudaMemcpy(d_topk_ids, h_topk_ids.data(), h_topk_ids.size() * sizeof(int32_t), cudaMemcpyHostToDevice)); + + // 4. 执行 Sort (Count + Scan) + cudaStream_t stream = 0; + launch_moe_sort(d_topk_ids, d_expert_counts, d_expert_offsets, num_tokens, top_k, num_experts, stream); + + // 5. 执行 Permute + launch_moe_permute(d_input, d_topk_ids, d_expert_offsets, d_sorted_input, d_sorted_row_map, d_expert_counts, + num_tokens, top_k, hidden_dim, num_experts, stream); + + CUDA_CHECK(cudaStreamSynchronize(stream)); + + // 6. 拷回结果 + std::vector h_offsets_gpu(num_experts + 1); + std::vector h_sorted_input_gpu(total_tasks * hidden_dim); + std::vector h_row_map_gpu(total_tasks); + + CUDA_CHECK(cudaMemcpy(h_offsets_gpu.data(), d_expert_offsets, h_offsets_gpu.size() * sizeof(int32_t), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_sorted_input_gpu.data(), d_sorted_input, h_sorted_input_gpu.size() * sizeof(float), cudaMemcpyDeviceToHost)); + CUDA_CHECK(cudaMemcpy(h_row_map_gpu.data(), d_sorted_row_map, h_row_map_gpu.size() * sizeof(int32_t), cudaMemcpyDeviceToHost)); + + // 7. 打印部分结果用于调试 + std::cout << "\n[GPU Offsets]: "; + for (auto v : h_offsets_gpu) std::cout << v << " "; + std::cout << std::endl; + + std::cout << "[GPU Row Map (First 10)]: "; + for (int i=0; i>> Speedup: {t_torch_p / t_cpp_p:.2f}x") + + # --- Decode Test --- + print(f"\n[Test 2] DECODE Phase (Small Batch / Latency Critical)") + print(f"Case: {DECODE_TESTCASES}") + + t_torch_d = benchmark_moe(torch_model, DECODE_TESTCASES, device, dtype, "PyTorch") + t_cpp_d = benchmark_moe(cpp_model, DECODE_TESTCASES, device, dtype, "C++") + print(f" >>> Speedup: {t_torch_d / t_cpp_d:.2f}x") + + print("\n" + "=" * 80) \ No newline at end of file