3 changes: 3 additions & 0 deletions .gitmodules
@@ -1,3 +1,6 @@
[submodule "third_party/spdlog"]
path = third_party/spdlog
url = https://github.com/gabime/spdlog.git
[submodule "moe_gu_ops/src/nvidia_kernels/cutlass"]
path = moe_gu_ops/src/nvidia_kernels/cutlass
url = https://github.com/NVIDIA/cutlass.git
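After checking out this change, fetch the new dependency with: git submodule update --init --recursive (or limit it to moe_gu_ops/src/nvidia_kernels/cutlass).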
108 changes: 108 additions & 0 deletions moe_gu_ops/pybind_gumoe.cc
@@ -0,0 +1,108 @@
#include <torch/extension.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h> // Required for automatic conversion of std::map and std::string

#include <iostream>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "gu_moe.h" // MoE header
#include "infinicore/tensor.hpp"
#include "infinicore/device.hpp"

namespace py = pybind11;

// 1. Converter: Torch Tensor -> Infini Tensor (zero-copy)
// This function only creates a "view"; it does not copy data.
// Safety: Module::load_state_dict performs a copy_from internally, so the data
// is copied into the model parameters while the source tensor is still alive.
infinicore::Tensor torch_to_infini_view(const torch::Tensor& t) {
    // 1. Get the shape
    infinicore::Shape shape;
    for (auto s : t.sizes()) shape.push_back(s);

    // 2. Get the data type (only F32 and F16 are supported)
    infinicore::DataType dtype;
    if (t.dtype() == torch::kFloat32) dtype = infinicore::DataType::F32;
    else if (t.dtype() == torch::kFloat16) dtype = infinicore::DataType::F16;
    else throw std::runtime_error("Unsupported dtype");

    // 3. Get the device
    infinicore::Device::Type dev_type = infinicore::Device::Type::CPU;
    int dev_id = 0;
    if (t.is_cuda()) {
        dev_type = infinicore::Device::Type::NVIDIA;
        dev_id = t.device().index();
    }

    // 4. Create a tensor view (from_blob)
    return infinicore::Tensor::from_blob(
        t.data_ptr(),
        shape,
        dtype,
        infinicore::Device(dev_type, dev_id)
    );
}

// =====================================================================
// 2. Converter: Infini Tensor -> Torch Tensor (for forward output)
// =====================================================================
torch::Tensor infini_to_torch_copy(infinicore::Tensor t) {
    std::vector<int64_t> sizes;
    for (auto s : t->shape()) sizes.push_back(s);

    auto options = torch::TensorOptions().dtype(torch::kFloat32); // assumes F32 output
    if (t->device().getType() == infinicore::Device::Type::NVIDIA) {
        options = options.device(torch::kCUDA, t->device().getIndex());
    } else {
        options = options.device(torch::kCPU);
    }

    // Create a view, then clone so the result owns its memory.
    return torch::from_blob(t->data(), sizes, options).clone();
}

// =====================================================================
// 3. Wrapper class
// =====================================================================
class PyGuMoeWrapper {
public:
    std::shared_ptr<infinicore::nn::GuMoeSparseMoeBlock> moe_block;

    // Constructor: receives the parameters passed in from Python.
    PyGuMoeWrapper(int num_experts, int hidden_dim, int intermediate_dim, int top_k, bool norm_topk) {
        // NVIDIA:0 is hard-coded here; add a device parameter if needed.
        infinicore::Device device(infinicore::Device::Type::NVIDIA, 0);

        moe_block = std::make_shared<infinicore::nn::GuMoeSparseMoeBlock>(
            num_experts, hidden_dim, intermediate_dim, top_k, norm_topk,
            infinicore::DataType::F32, device
        );
    }

    // Forward
    torch::Tensor forward(torch::Tensor hidden_states) {
        auto infini_input = torch_to_infini_view(hidden_states);
        auto infini_output = moe_block->forward(infini_input);
        return infini_to_torch_copy(infini_output);
    }

    // [Core] Weight-loading interface.
    // Receives a Python Dict[str, Tensor].
    void load_state_dict(std::map<std::string, torch::Tensor> weights) {
        std::unordered_map<std::string, infinicore::Tensor> infini_dict;

        // Keep the .contiguous() results alive until load_state_dict has copied
        // them; otherwise a non-contiguous input would yield a temporary whose
        // storage is freed before the copy, leaving the view dangling.
        std::vector<torch::Tensor> keep_alive;
        keep_alive.reserve(weights.size());

        for (auto const& [name, tensor] : weights) {
            keep_alive.push_back(tensor.contiguous());
            infini_dict.emplace(name, torch_to_infini_view(keep_alive.back()));
        }
        moe_block->load_state_dict(infini_dict);

        std::cout << "[C++] load_state_dict finished. Loaded " << infini_dict.size() << " tensors." << std::endl;
    }
};

// =====================================================================
// 4. Python module definition
// =====================================================================
PYBIND11_MODULE(gu_moe_ops, m) {
    py::class_<PyGuMoeWrapper>(m, "GuMoeBlock")
        .def(py::init<int, int, int, int, bool>())
        .def("forward", &PyGuMoeWrapper::forward)
        .def("load_state_dict", &PyGuMoeWrapper::load_state_dict);
}
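For illustration, a minimal Python sketch of driving the compiled extension. This is a hedged example: the state_dict key names and weight shapes below are assumptions and must match whatever parameter names GuMoeSparseMoeBlock actually registers.

# Hypothetical usage of the gu_moe_ops extension built by the setup.py below.
import torch
import gu_moe_ops

E, H, I, K = 8, 1024, 2816, 2
block = gu_moe_ops.GuMoeBlock(E, H, I, K, True)  # num_experts, hidden, inter, top_k, norm_topk

block.load_state_dict({
    "router.weight":        torch.randn(E, H, device="cuda"),         # assumed key name
    "experts.gate_up_proj": torch.randn(E, 2 * I, H, device="cuda"),  # assumed key name
    "experts.down_proj":    torch.randn(E, H, I, device="cuda"),      # assumed key name
})

x = torch.randn(4, H, device="cuda", dtype=torch.float32)  # F32 on NVIDIA:0 (hard-coded above)
y = block.forward(x)  # returns a cloned torch.Tensor that owns its memory
print(y.shape)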

99 changes: 99 additions & 0 deletions moe_gu_ops/setup.py
@@ -0,0 +1,99 @@
import os
import sys

try:
    import torch
    from torch.utils import cpp_extension

    # Monkey-patch torch's private CUDA version check so minor toolkit/driver
    # mismatches do not abort the build.
    def no_op_check(compiler_name, compiler_version):
        return True
    cpp_extension._check_cuda_version = no_op_check
    os.environ["TORCH_DONT_CHECK_CUDA_VERSION_COMPATIBILITY"] = "1"
except ImportError:
    # Without torch the build cannot proceed (the imports below would fail).
    print("[Error] PyTorch is required to build this extension.")
    sys.exit(1)
# =================================================================================

if hasattr(torch, '_C'):
    # cpp_extension derives -D_GLIBCXX_USE_CXX11_ABI from this flag.
    torch._C._GLIBCXX_USE_CXX11_ABI = True

from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import pybind11

# Get the current directory (i.e. .../InfiniCore/moe_gu_ops)
CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))

# -------------------------------------------------------------------------
# 1. Path configuration (derived dynamically)
# -------------------------------------------------------------------------
# Prefer the environment variables; otherwise derive the paths from the
# directory layout. The assumed structure is:
# InfiniCore/
# ├── moe_gu_ops/   <-- we are here
# ├── include/
# └── build/
INFINI_SRC_ROOT = os.getenv("INFINICORE_ROOT")
if not INFINI_SRC_ROOT:
    # One level up is the InfiniCore root directory.
    INFINI_SRC_ROOT = os.path.abspath(os.path.join(CURRENT_DIR, "../"))

INFINI_LIB_DIR = os.getenv("INFINICORE_BUILD_DIR")
if not INFINI_LIB_DIR:
    # Default build path.
    INFINI_LIB_DIR = os.path.join(INFINI_SRC_ROOT, 'build/linux/x86_64/release')

print(f"[Info] InfiniCore Root: {INFINI_SRC_ROOT}")
print(f"[Info] InfiniCore Lib Dir: {INFINI_LIB_DIR}")

# Check that the library directory exists, so a failure here is easy to diagnose.
if not os.path.exists(INFINI_LIB_DIR):
    print(f"[Warning] Library directory not found: {INFINI_LIB_DIR}")
    print("Please check if InfiniCore is built or set INFINICORE_BUILD_DIR env var.")

# -------------------------------------------------------------------------
# 2. Static libraries to link
# -------------------------------------------------------------------------
libs = [
    os.path.join(INFINI_LIB_DIR, 'libinfini-utils.a'),
    os.path.join(INFINI_LIB_DIR, 'libinfiniop-nvidia.a'),
    os.path.join(INFINI_LIB_DIR, 'libinfiniccl-nvidia.a'),
    os.path.join(INFINI_LIB_DIR, 'libinfinirt-nvidia.a')
]

setup(
    name='gu_moe_ops',
    version='0.1.0',
    ext_modules=[
        CUDAExtension(
            name='gu_moe_ops',
            sources=[
                'pybind_gumoe.cc',
                'src/gumoe.cpp',
                'src/gu_mul.cc',
                'src/gu_topk_softmax.cc',
                'src/nvidia_kernels/gu_reduce.cu',
                'src/nvidia_kernels/gu_sort.cu',
            ],
            include_dirs=[
                pybind11.get_include(),
                os.path.join(INFINI_SRC_ROOT, 'include'),  # headers from the parent project
                os.path.join(CURRENT_DIR, 'src'),
                "/usr/local/cuda/include"
            ],
            extra_objects=libs,

            extra_link_args=[
                '-Wl,--allow-shlib-undefined'
            ],

            extra_compile_args={
                'cxx': ['-O3', '-std=c++17', '-D_GLIBCXX_USE_CXX11_ABI=1'],
                'nvcc': [
                    '-O3', '--use_fast_math',
                    # Note: the hard-coded -gencode flags were removed; the build
                    # reads TORCH_CUDA_ARCH_LIST from the environment instead.
                ]
            }
        )
    ],
    cmdclass={'build_ext': BuildExtension}
)
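A minimal sketch of invoking the build, with placeholder paths. TORCH_CUDA_ARCH_LIST is the standard variable torch.utils.cpp_extension consults when no explicit -gencode flags are passed (see extra_compile_args above).

# Hypothetical build driver; the paths are placeholders for your checkout.
import os
import subprocess
import sys

env = dict(
    os.environ,
    INFINICORE_ROOT="/path/to/InfiniCore",                                  # placeholder
    INFINICORE_BUILD_DIR="/path/to/InfiniCore/build/linux/x86_64/release",  # placeholder
    TORCH_CUDA_ARCH_LIST="8.0;8.6",  # target SMs, e.g. A100 and RTX 30xx
)
subprocess.run([sys.executable, "setup.py", "build_ext", "--inplace"],
               env=env, check=True)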
69 changes: 69 additions & 0 deletions moe_gu_ops/src/gu_moe.h
@@ -0,0 +1,69 @@
#ifndef GU_MOE_H
#define GU_MOE_H

#include <vector>
#include <string>
#include <memory>
#include "infinicore/tensor.hpp"
#include "infinicore/nn/module.hpp"
#include "infinicore/device.hpp"
#include "infinicore/ops.hpp"

namespace infinicore::nn {

// Router
class GuMoeTopkRounter : public Module {
private:
    int top_k_;
    int num_experts_;
    int hidden_dim_;
    bool norm_topk_prob_;
    Parameter weight_;
    infiniopHandle_t handle_ = nullptr;

public:
    GuMoeTopkRounter(int num_experts, int hidden_dim, int top_k, bool norm_topk_prob,
                     const DataType &dtype, const Device &device);
    ~GuMoeTopkRounter();
    void set_weight(Tensor w); // set the router weight
    std::pair<Tensor, Tensor> forward(const Tensor &hidden_states) const;
};

// Experts
class GuMoeExperts : public Module {
private:
    int num_experts_;
    int hidden_dim_;
    int intermediate_dim_;
    Parameter gate_up_proj_;
    Parameter down_proj_;
    infiniopHandle_t handle_ = nullptr;
    Device device_;

public:
    GuMoeExperts(int num_experts, int hidden_dim, int intermediate_dim,
                 const DataType& dtype, const Device& device);
    ~GuMoeExperts();
    void set_weights(Tensor gate_up, Tensor down); // set the expert weights
    Tensor forward(const Tensor& hidden_states, const Tensor& top_k_index, const Tensor& top_k_values) const;
};

// Block (the MoE layer as a whole)
class GuMoeSparseMoeBlock : public Module {
private:
    std::shared_ptr<GuMoeTopkRounter> router_;
    std::shared_ptr<GuMoeExperts> experts_;

public:
    GuMoeSparseMoeBlock(int num_experts, int hidden_dim, int intermediate_dim,
                        int top_k, bool norm_topk,
                        const DataType& dtype, const Device& device);

    // ✅ Unified interface for setting all weights.
    void set_weights(Parameter gate_up, Parameter down, Parameter router_weight);

    // ✅ Key point: forward only needs the input (the router runs internally).
    Tensor forward(const Tensor& hidden_states);
};

} // namespace infinicore::nn
#endif // GU_MOE_H
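To make the contract of these classes concrete, here is a hedged PyTorch reference of what a standard top-k sparse MoE forward computes. The CUDA implementation behind this header may differ in details; in particular the softmax-then-top-k order, the SwiGLU activation, and the fused [gate|up] weight layout are assumptions.

# Reference for a standard top-k sparse MoE forward (assumptions noted above).
import torch
import torch.nn.functional as F

def moe_forward_reference(hidden, router_w, gate_up_w, down_w, top_k, norm_topk):
    # hidden: [T, H], router_w: [E, H], gate_up_w: [E, 2*I, H], down_w: [E, H, I]
    probs = torch.softmax(hidden @ router_w.t(), dim=-1)  # [T, E] routing probabilities
    vals, idx = probs.topk(top_k, dim=-1)                 # [T, K] weights and expert ids
    if norm_topk:                                         # renormalize the top-k weights
        vals = vals / vals.sum(dim=-1, keepdim=True)
    out = torch.zeros_like(hidden)
    for e in range(router_w.shape[0]):
        tok, slot = (idx == e).nonzero(as_tuple=True)     # tokens routed to expert e
        if tok.numel() == 0:
            continue
        gate, up = (hidden[tok] @ gate_up_w[e].t()).chunk(2, dim=-1)
        y = (F.silu(gate) * up) @ down_w[e].t()           # [n, H] expert output
        out.index_add_(0, tok, y * vals[tok, slot].unsqueeze(-1))
    return out

In this mapping, GuMoeTopkRounter::forward corresponds to the logits → softmax → top-k step (returning the pair of values and indices), and GuMoeExperts::forward to the per-expert loop with weighted accumulation.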
49 changes: 49 additions & 0 deletions moe_gu_ops/src/gu_mul.cc
@@ -0,0 +1,49 @@
#include "infinicore/tensor.hpp"
#include "infinicore/context/context.hpp"
#include "infiniop/ops/mul.h"
#include "gu_mul.h"

namespace infinicore::op {

// 【修改点 1】函数签名增加 infiniopHandle_t handle
Tensor mul(Tensor a, Tensor b, infiniopHandle_t handle) {

Tensor c = Tensor::empty(a->shape(), a->dtype(), a->device());

infiniopMulDescriptor_t desc;
infiniopCreateMulDescriptor(
handle,
&desc,
c->desc(),
a->desc(),
b->desc()
);

size_t workspace_size = 0;
std::shared_ptr<infinicore::Memory> workspace_mem = nullptr;
void* workspace = nullptr;
infiniopGetMulWorkspaceSize(desc, &workspace_size);

if (workspace_size > 0) {
workspace_mem = context::allocateMemory(workspace_size);
workspace = workspace_mem->data();
}

void* stream = nullptr;

infiniopMul(
desc,
workspace,
workspace_size,
c->data(),
a->data(),
b->data(),
stream
);

infiniopDestroyMulDescriptor(desc);

return c;
}

} // namespace infinicore::op
14 changes: 14 additions & 0 deletions moe_gu_ops/src/gu_mul.h
@@ -0,0 +1,14 @@
#ifndef GU_MUL_H
#define GU_MUL_H

#include "infinicore/tensor.hpp"
#include "infinicore/context/context.hpp"
#include "infiniop/ops/mul.h"

namespace infinicore::op {

// Elementwise multiplication c = a * b, executed through the given infiniop handle.
Tensor mul(Tensor a, Tensor b, infiniopHandle_t handle);

} // namespace infinicore::op

#endif // GU_MUL_H