[metal] Add fused RMS_NORM + MUL + SWIGLU for Qwen3Next #16143
Conversation
- Added F=4 fusion in kernel_rms_norm_fuse_impl
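For context, a sketch of the math this fusion computes per row (matching the `RMSNorm(x) * silu(gate) * weight` formulation used in the reference test below; $w$ is the norm weight, $g$ the gate):

$$
\operatorname{out} \;=\; \underbrace{\frac{x}{\sqrt{\operatorname{mean}(x^2) + \varepsilon}} \odot w}_{\texttt{rms\_norm} \,+\, \texttt{mul}} \;\odot\; \underbrace{\operatorname{silu}(g)}_{\texttt{swiglu\_split}}, \qquad \operatorname{silu}(g) = g \cdot \sigma(g)
$$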
Hi @ggerganov! First of all, thanks a lot for the ongoing Metal-backend refactor; I really appreciate the effort you've put into it. I noticed that the Metal backend has undergone two rounds of architectural changes.
Could you let me know whether the Metal-backend refactor is now considered complete, and whether this is a good time to add a new branch to the RMS_NORM fusion pipeline? My plan is to introduce an extra F == 4 branch in kernel_rms_norm_fuse_impl.
Since the community is implementing Qwen3-Next support on the CPU side, I'd like to provide the corresponding Metal optimization. Below are the golden-reference generator and the precision test from my earlier draft. If the refactor is stable, I'll port them to the latest upstream and submit a new PR; please let me know your thoughts when you have a moment.

generate_reference.py
import numpy as np
import torch
import torch.nn.functional as F
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextRMSNormGated
def generate_test_cases():
    """Generate several test cases, including weight and gate tensors."""
    test_cases = [
        {"shape": (3, 4, 5, 64), "eps": 1e-6, "name": "standard"},
        {"shape": (1, 1, 1, 128), "eps": 1e-5, "name": "simple"},
        {"shape": (2, 8, 16, 32), "eps": 1e-4, "name": "medium"},
        {"shape": (1, 2, 3, 4), "eps": 1e-4, "name": "work"},
    ]
    for case in test_cases:
        shape = case["shape"]
        eps = case["eps"]
        name = case["name"]
        hidden_size = shape[-1]  # RMSNorm operates over the last dimension
        # fix the seeds for reproducibility
        np.random.seed(43)
        torch.manual_seed(43)
        # generate random input and gate tensors
        x = np.random.uniform(-5.0, 5.0, shape).astype(np.float32)
        gate = np.random.uniform(-2.0, 2.0, shape).astype(np.float32)  # gate tensor
        # create the Qwen3NextRMSNormGated layer and randomize its weight
        rms_norm_gated = Qwen3NextRMSNormGated(hidden_size, eps=eps)
        # the weight is normally initialized to 1; randomize it here for testing
        with torch.no_grad():
            rms_norm_gated.weight.data = torch.randn(hidden_size) * 0.1 + 1.0  # small perturbation around 1.0
        # PyTorch reference computation
        pt_x = torch.from_numpy(x)
        pt_gate = torch.from_numpy(gate)
        pt_out = rms_norm_gated(pt_x, gate=pt_gate)
        # save the input tensor
        x.tofile(f"input_{name}.bin")
        # save the gate tensor
        gate.tofile(f"gate_{name}.bin")
        # save the weight tensor
        rms_norm_gated.weight.detach().numpy().astype(np.float32).tofile(f"weight_{name}.bin")
        # save the expected output tensor
        pt_out.detach().numpy().tofile(f"expected_{name}.bin")
        # save the metadata
        with open(f"meta_{name}.txt", "w") as f:
            f.write(f"shape: {shape}\n")
            f.write(f"eps: {eps}\n")
            f.write(f"hidden_size: {hidden_size}\n")
            f.write(f"elements: {np.prod(shape)}\n")
            f.write(f"weight_elements: {hidden_size}\n")
            f.write(f"gate_elements: {np.prod(shape)}\n")
        print(f"Generated test case '{name}': shape={shape}, eps={eps}, hidden_size={hidden_size}")
        print(f"  Files: input_{name}.bin, gate_{name}.bin, weight_{name}.bin, expected_{name}.bin, meta_{name}.txt")

if __name__ == "__main__":
    generate_test_cases()

test_metal_precision.cpp
#include <ggml.h>
#include <ggml-backend.h>
#include <ggml-alloc.h>
#include <ggml-metal.h>
#include <array>
#include <vector>
#include <string>
#include <fstream>
#include <iostream>
#include <cmath>
#include <iomanip>
#include <cstring>
#include <cstdlib>
#include <unistd.h>
struct TestCase {
std::string name;
std::array<int64_t, 4> shape;
float eps;
size_t elements;
size_t weight_elements;
size_t gate_elements; // number of gate tensor elements
};
struct TestConfig {
std::string data_path = "./";
std::vector<std::string> test_cases = {"standard", "simple", "medium", "work"};
double tolerance = 1e-5;
};
// read binary data from a file
std::vector<float> load_binary_data(const std::string& filename, size_t expected_size) {
std::ifstream file(filename, std::ios::binary);
if (!file) {
throw std::runtime_error("Cannot open file: " + filename);
}
std::vector<float> data(expected_size);
file.read(reinterpret_cast<char*>(data.data()), expected_size * sizeof(float));
if (file.gcount() != static_cast<std::streamsize>(expected_size * sizeof(float))) {
throw std::runtime_error("File size mismatch: " + filename);
}
return data;
}
// parse the metadata file
TestCase parse_metadata(const std::string& name) {
std::ifstream file(name);
if (!file) {
throw std::runtime_error("Cannot open metadata file for: " + name);
}
TestCase test_case;
test_case.name = name;
int64_t pytorch_shape[4];
std::string line;
while (std::getline(file, line)) {
if (line.find("shape: ") == 0) {
sscanf(line.c_str(), "shape: (%lld, %lld, %lld, %lld)",
&pytorch_shape[0], &pytorch_shape[1],
&pytorch_shape[2], &pytorch_shape[3]);
// convert to GGML dimension order (width, height, seq_len, batch)
test_case.shape[0] = pytorch_shape[3]; // width
test_case.shape[1] = pytorch_shape[2]; // height
test_case.shape[2] = pytorch_shape[1]; // seq_len
test_case.shape[3] = pytorch_shape[0]; // batch
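// e.g. PyTorch shape (3, 4, 5, 64) maps to GGML ne = {64, 5, 4, 3}:
// ne[0] is the innermost (contiguous) dimension, hence the reversed order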
} else if (line.find("eps: ") == 0) {
sscanf(line.c_str(), "eps: %f", &test_case.eps);
} else if (line.find("elements: ") == 0) {
sscanf(line.c_str(), "elements: %zu", &test_case.elements);
} else if (line.find("weight_elements: ") == 0) {
sscanf(line.c_str(), "weight_elements: %zu", &test_case.weight_elements);
} else if (line.find("gate_elements: ") == 0) {
sscanf(line.c_str(), "gate_elements: %zu", &test_case.gate_elements);
}
}
return test_case;
}
// precision metrics comparing GGML output against the PyTorch reference
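// max_abs_error : max_i |a_i - b_i|
// mean_abs_error: mean_i |a_i - b_i|
// rmse          : sqrt(mean_i (a_i - b_i)^2)
// relative_error: ||a - b||_2 / ||b||_2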
struct PrecisionMetrics {
double max_abs_error;
double mean_abs_error;
double rmse;
double relative_error;
bool passed;
};
PrecisionMetrics calculate_metrics(const std::vector<float>& ggml_output,
const std::vector<float>& pytorch_output,
double tolerance = 1e-5) {
PrecisionMetrics metrics = {};
if (ggml_output.size() != pytorch_output.size()) {
throw std::runtime_error("Output size mismatch");
}
double sum_abs_error = 0.0;
double sum_squared_error = 0.0;
double sum_pytorch_squared = 0.0;
for (size_t i = 0; i < ggml_output.size(); ++i) {
double abs_error = std::abs(ggml_output[i] - pytorch_output[i]);
double squared_error = abs_error * abs_error;
metrics.max_abs_error = std::max(metrics.max_abs_error, abs_error);
sum_abs_error += abs_error;
sum_squared_error += squared_error;
sum_pytorch_squared += pytorch_output[i] * pytorch_output[i];
}
metrics.mean_abs_error = sum_abs_error / ggml_output.size();
metrics.rmse = std::sqrt(sum_squared_error / ggml_output.size());
metrics.relative_error = std::sqrt(sum_squared_error / sum_pytorch_squared);
metrics.passed = metrics.max_abs_error < tolerance;
return metrics;
}
// build the Qwen3NextRMSNormGated operation: RMSNorm(x) * silu(gate) * weight
ggml_tensor* create_qwen3_next_rms_norm_gated(ggml_context* ctx, ggml_tensor* input,
ggml_tensor* gate, ggml_tensor* weight, float eps) {
// step 1: RMS normalization
ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);
// step 2: apply the SiLU activation to the gate tensor
ggml_tensor* silu_gate = ggml_silu(ctx, gate);
// step 3: multiply the normalized result by the activated gate
ggml_tensor* gated_result = ggml_mul(ctx, rms_norm_result, silu_gate);
// step 4: multiply by the weight tensor
ggml_tensor* final_result = ggml_mul(ctx, gated_result, weight);
return final_result;
}
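// note: the three explicit ops above are the unfused baseline; the two variants
// below express the same math in orderings that allow more backend fusion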
// leaner implementation using SwiGLU
ggml_tensor* create_qwen3_next_rms_norm_gated_optimized(ggml_context* ctx, ggml_tensor* input,
ggml_tensor* gate, ggml_tensor* weight, float eps) {
// step 1: RMS normalization
ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);
// step 2: use swiglu_split to fuse the SiLU gating and the multiply
ggml_tensor* swiglu_result = ggml_swiglu_split(ctx, gate, rms_norm_result);
// step 3: multiply by the weight tensor
ggml_tensor* final_result = ggml_mul(ctx, swiglu_result, weight);
return final_result;
}
// optimized ordering of Qwen3NextRMSNormGated: norm first, then weight, then swiglu_split
ggml_tensor* create_qwen3_next_rms_norm_gated_optimized_order(ggml_context* ctx, ggml_tensor* input,
ggml_tensor* gate, ggml_tensor* weight, float eps) {
// step 1: RMS normalization
ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);
// step 2: multiply by the weight tensor (can be fused by the Metal backend into RMS_NORM_MUL)
ggml_tensor* weighted_result = ggml_mul(ctx, rms_norm_result, weight);
// step 3: SwiGLU gating - the gate goes through SiLU and is multiplied with weighted_result
ggml_tensor* final_result = ggml_swiglu_split(ctx, gate, weighted_result);
return final_result;
}
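// all three variants compute the same function, since elementwise multiplies
// commute: silu(g) * (rms_norm(x) * w) == (silu(g) * rms_norm(x)) * w;
// the last ordering is the pattern the new F == 4 Metal fusion targets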
// run a single test case
bool run_test_case(const std::string& test_name, ggml_backend_t backend, const TestConfig& config) {
try {
// parse the test-case metadata
TestCase test_case = parse_metadata(config.data_path + "meta_" + test_name + ".txt");
// load the input, gate, weight, and expected output
auto input_data = load_binary_data(config.data_path + "input_" + test_name + ".bin", test_case.elements);
auto gate_data = load_binary_data(config.data_path + "gate_" + test_name + ".bin", test_case.gate_elements);
auto weight_data = load_binary_data(config.data_path + "weight_" + test_name + ".bin", test_case.weight_elements);
auto expected_output = load_binary_data(config.data_path + "expected_" + test_name + ".bin", test_case.elements);
// create the GGML context
ggml_init_params params = {
.mem_size = 64 * 1024 * 1024, // 64 MB (extra headroom for the additional tensors)
.mem_buffer = nullptr,
.no_alloc = true,
};
ggml_context* ctx = ggml_init(params);
if (!ctx) {
throw std::runtime_error("Failed to initialize GGML context");
}
// build the compute graph
ggml_tensor* input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, test_case.shape.data());
ggml_tensor* gate = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, test_case.shape.data()); // gate tensor has the same shape as the input
// weight tensor shape: the weight lives in GGML dim 0 (PyTorch's last dimension)
int64_t weight_shape[4] = {static_cast<int64_t>(test_case.weight_elements), 1, 1, 1};
ggml_tensor* weight = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, weight_shape);
// build the Qwen3NextRMSNormGated graph (the two commented-out variants are the unfused/partially fused formulations)
// ggml_tensor* output = create_qwen3_next_rms_norm_gated(ctx, input, gate, weight, test_case.eps);
// ggml_tensor* output = create_qwen3_next_rms_norm_gated_optimized(ctx, input, gate, weight, test_case.eps);
ggml_tensor* output = create_qwen3_next_rms_norm_gated_optimized_order(ctx, input, gate, weight, test_case.eps);
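// this ordering exercises the fused RMS_NORM + MUL + SWIGLU_SPLIT path added by this PR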
ggml_cgraph* graph = ggml_new_graph(ctx);
ggml_build_forward_expand(graph, output);
// allocate backend memory
ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);
if (!buffer) {
throw std::runtime_error("Failed to allocate backend buffer");
}
// upload the input, gate, and weight data
ggml_backend_tensor_set(input, input_data.data(), 0, input_data.size() * sizeof(float));
ggml_backend_tensor_set(gate, gate_data.data(), 0, gate_data.size() * sizeof(float));
ggml_backend_tensor_set(weight, weight_data.data(), 0, weight_data.size() * sizeof(float));
// run the computation
ggml_status status = ggml_backend_graph_compute(backend, graph);
if (status != GGML_STATUS_SUCCESS) {
throw std::runtime_error("Graph computation failed");
}
// optional: for Xcode GPU frame capture, synchronize here and sleep briefly to
// leave time to press the "Capture GPU Frame" button:
// ggml_backend_synchronize(backend);
// sleep(3);
// read back the output
std::vector<float> ggml_output(test_case.elements);
ggml_backend_tensor_get(output, ggml_output.data(), 0, ggml_output.size() * sizeof(float));
// compute precision metrics
PrecisionMetrics metrics = calculate_metrics(ggml_output, expected_output, config.tolerance);
// print the results
std::cout << "\n=== Test Case: " << test_name << " ===" << std::endl;
std::cout << "GGML Shape: [" << test_case.shape[0] << ", " << test_case.shape[1]
<< ", " << test_case.shape[2] << ", " << test_case.shape[3] << "]" << std::endl;
std::cout << "EPS: " << test_case.eps << std::endl;
std::cout << "Weight Elements: " << test_case.weight_elements << std::endl;
std::cout << "Gate Elements: " << test_case.gate_elements << std::endl;
std::cout << std::fixed << std::setprecision(8);
std::cout << "Max Absolute Error: " << metrics.max_abs_error << std::endl;
std::cout << "Mean Absolute Error: " << metrics.mean_abs_error << std::endl;
std::cout << "RMSE: " << metrics.rmse << std::endl;
std::cout << "Relative Error: " << metrics.relative_error << std::endl;
std::cout << "Status: " << (metrics.passed ? "PASS" : "FAIL") << std::endl;
// clean up
ggml_backend_buffer_free(buffer);
ggml_free(ctx);
return metrics.passed;
} catch (const std::exception& e) {
std::cerr << "Error in test case " << test_name << ": " << e.what() << std::endl;
return false;
}
}
TestConfig parse_args(int argc, char* argv[]) {
TestConfig config;
const char* env_data_path = std::getenv("GGML_TEST_DATA_PATH");
if (env_data_path) {
config.data_path = env_data_path;
if (!config.data_path.empty() && config.data_path.back() != '/') {
config.data_path += '/';
}
}
for (int i = 1; i < argc; i++) {
if (strcmp(argv[i], "--data-path") == 0 && i + 1 < argc) {
config.data_path = argv[++i];
if (!config.data_path.empty() && config.data_path.back() != '/') {
config.data_path += '/';
}
} else if (strcmp(argv[i], "--tolerance") == 0 && i + 1 < argc) {
config.tolerance = std::stod(argv[++i]);
} else if (strcmp(argv[i], "--test-case") == 0 && i + 1 < argc) {
// allow selecting a specific test case
config.test_cases.clear();
config.test_cases.push_back(argv[++i]);
} else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {
std::cout << "Usage: " << argv[0] << " [options]\n";
std::cout << "Options:\n";
std::cout << " --data-path PATH Path to test data files (default: ./)\n";
std::cout << " --tolerance VALUE Error tolerance for tests (default: 1e-5)\n";
std::cout << " --test-case NAME Run specific test case (default: all)\n";
std::cout << " --help, -h Show this help message\n";
exit(0);
} else {
std::cerr << "Unknown argument: " << argv[i] << std::endl;
std::cerr << "Use --help for usage information" << std::endl;
exit(1);
}
}
return config;
}
int main(int argc, char* argv[]) {
TestConfig config = parse_args(argc, argv);
// initialize the Metal backend
ggml_backend_t backend = ggml_backend_metal_init();
if (!backend) {
std::cerr << "Failed to initialize Metal backend" << std::endl;
return 1;
}
int passed = 0;
int total = config.test_cases.size();
std::cout << "Running GGML Metal vs PyTorch Qwen3NextRMSNormGated Precision Tests" << std::endl;
std::cout << "Data Path: " << config.data_path << std::endl;
std::cout << "Tolerance: " << config.tolerance << std::endl;
std::cout << "Test Cases: " << total << std::endl;
std::cout << "Note: PyTorch dimensions are automatically converted to GGML order" << std::endl;
std::cout << "================================================" << std::endl;
// run all test cases
for (const auto& test_name : config.test_cases) {
if (run_test_case(test_name, backend, config)) {
passed++;
}
}
// print the test summary
std::cout << "\n=== Test Summary ===" << std::endl;
std::cout << "Passed: " << passed << "/" << total << std::endl;
std::cout << "Success Rate: " << (100.0 * passed / total) << "%" << std::endl;
if (passed == total) {
std::cout << "All tests PASSED!" << std::endl;
} else {
std::cout << "Some tests FAILED!" << std::endl;
}
ggml_backend_free(backend);
return (passed == total) ? 0 : 1;
}
This PR adds a new fused operator in the Metal backend to support Qwen3NextRMSNormGated from Qwen3-Next models:
- kernel_rms_norm_fuse_impl with F == 4: fuses rms_norm + mul + swiglu_split into a single kernel.
- Matches Qwen3NextRMSNormGated in modeling_qwen3_next.py.
- Registers SWIGLU_SPLIT as a fusible tail op.
- Extends the test_rms_norm_mul_add test class to cover the new fusion pattern.
- Validated with test-backend-ops with full parameter coverage (broadcast, eps, multi_add).

Tested on Apple M4 Pro. All tests pass.
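For reference, a minimal sketch of the graph pattern the new fusion targets, using the same ggml calls as the test code above (the helper name and variable names are placeholders, not identifiers from the actual patch; the fusion-detection logic in ggml-metal is not reproduced here):

```cpp
#include <ggml.h>

// hypothetical helper illustrating the fusible op chain
static ggml_tensor * gated_rms_norm(ggml_context * ctx,
        ggml_tensor * x, ggml_tensor * w, ggml_tensor * gate, float eps) {
    ggml_tensor * t = ggml_rms_norm(ctx, x, eps);  // RMS_NORM
    t = ggml_mul(ctx, t, w);                       // + MUL (norm weight)
    t = ggml_swiglu_split(ctx, gate, t);           // + SWIGLU_SPLIT tail: silu(gate) * t
    return t;
}
```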
Related to #15940