Skip to content

Conversation

@MemoryIt
Copy link

This PR adds a new fused operator in the Metal backend to support Qwen3NextRMSNormGated from Qwen3-Next models:

  • Extends kernel_rms_norm_fuse_impl with F == 4: fuses rms_norm + mul + swiglu_split into a single kernel.
  • Matches the behavior of Qwen3NextRMSNormGated in modeling_qwen3_next.py.
  • Extends fusion detection logic to recognize SWIGLU_SPLIT as a fusible tail op.
  • Expands test_rms_norm_mul_add test class to cover the new fusion pattern.
  • ✅ Passes test-backend-ops with full parameter coverage (broadcast, eps, multi_add).
  • ✅ Passes local CI on macOS (CPU + Metal backends, numerical consistency verified).

Tested on Apple M4 Pro. All tests pass.

Related to #15940

@github-actions github-actions bot added testing Everything test related ggml changes relating to the ggml tensor library for machine learning Apple Metal https://en.wikipedia.org/wiki/Metal_(API) labels Sep 21, 2025
@MemoryIt
Copy link
Author

Hi @ggerganov !

First of all, thanks a lot for the ongoing Metal-backend refactor—really appreciate the effort you put into it!

I noticed that the Metal backend has undergone two rounds of architectural changes:

  1. The first occurred right after I successfully fused RMS_NORM + MUL + SWIGLU and opened my initial PR.
  2. The second one I saw today: you extended the shape support of the RMS_NORM fusion kernel and unified the detection logic for norm and RMS_NORM fusion.

Could you let me know if the Metal-backend refactor is now considered complete? Is it a good time to add a new branch to the RMS_NORM fusion pipeline?

My plan is to introduce an extra branch parameter to ggml_metal_library_get_pipeline_norm so we can select:

  • Branch 0: the existing RMS_NORM + MUL + ADD
  • Branch 1: RMS_NORM + MUL + SWIGLU (the normalization used in Qwen3-Next)

Since the community is implementing Qwen3-Next support on the CPU side, I'd like to provide the corresponding Metal optimization. Below is the “golden-reference” generation and test code from my earlier draft. If the refactor is stable, I'll port it to the latest upstream and submit a new PR.

Please let me know your thoughts when you have a moment. I'm ready to proceed once the refactor is stable.

generate_reference.py
import numpy as np      
import torch      
import torch.nn.functional as F  
from transformers.models.qwen3_next.modeling_qwen3_next import Qwen3NextRMSNormGated  
      
def generate_test_cases():      
    """生成多组测试用例,包含权重张量和门控张量"""      
    test_cases = [      
        {"shape": (3, 4, 5, 64), "eps": 1e-6, "name": "standard"},      
        {"shape": (1, 1, 1, 128), "eps": 1e-5, "name": "simple"},      
        {"shape": (2, 8, 16, 32), "eps": 1e-4, "name": "medium"},
        {"shape": (1, 2, 3, 4), "eps": 1e-4, "name": "work"},      
    ]      
          
    for case in test_cases:      
        shape = case["shape"]      
        eps = case["eps"]      
        name = case["name"]      
        hidden_size = shape[-1]  # RMSNorm作用于最后一个维度    
              
        # 设置种子确保可重现      
        np.random.seed(43)      
        torch.manual_seed(43)      
              
        # 生成随机输入和门控张量  
        x = np.random.uniform(-5.0, 5.0, shape).astype(np.float32)      
        gate = np.random.uniform(-2.0, 2.0, shape).astype(np.float32)  # 门控张量  
              
        # 创建Qwen3NextRMSNormGated层并生成随机权重    
        rms_norm_gated = Qwen3NextRMSNormGated(hidden_size, eps=eps)    
            
        # 为权重生成随机值(通常初始化为1,但这里生成随机值用于测试)    
        with torch.no_grad():    
            rms_norm_gated.weight.data = torch.randn(hidden_size) * 0.1 + 1.0  # 围绕1.0的小扰动    
              
        # PyTorch 计算      
        pt_x = torch.from_numpy(x)      
        pt_gate = torch.from_numpy(gate)  
        pt_out = rms_norm_gated(pt_x, gate=pt_gate)      
              
        # 保存输入张量    
        x.tofile(f"input_{name}.bin")      
          
        # 保存门控张量  
        gate.tofile(f"gate_{name}.bin")  
            
        # 保存权重张量    
        rms_norm_gated.weight.detach().numpy().astype(np.float32).tofile(f"weight_{name}.bin")    
            
        # 保存输出张量    
        pt_out.detach().numpy().tofile(f"expected_{name}.bin")      
              
        # 保存元数据      
        with open(f"meta_{name}.txt", "w") as f:      
            f.write(f"shape: {shape}\n")      
            f.write(f"eps: {eps}\n")      
            f.write(f"hidden_size: {hidden_size}\n")    
            f.write(f"elements: {np.prod(shape)}\n")    
            f.write(f"weight_elements: {hidden_size}\n")  
            f.write(f"gate_elements: {np.prod(shape)}\n")  
              
        print(f"Generated test case '{name}': shape={shape}, eps={eps}, hidden_size={hidden_size}")    
        print(f"  Files: input_{name}.bin, gate_{name}.bin, weight_{name}.bin, expected_{name}.bin, meta_{name}.txt")    
      
if __name__ == "__main__":      
    generate_test_cases()
test_metal_precision.cpp
#include <ggml.h>      
#include <ggml-backend.h>      
#include <ggml-alloc.h>      
#include <ggml-metal.h>    
      
#include <vector>      
#include <string>      
#include <fstream>      
#include <iostream>      
#include <cmath>      
#include <iomanip>    
#include <cstring>    
#include <cstdlib>  

#include <unistd.h> 
  
struct TestCase {      
    std::string name;      
    std::array<int64_t, 4> shape;      
    float eps;      
    size_t elements;    
    size_t weight_elements;  
    size_t gate_elements;  // 新增门控张量元素数量  
};      
  
struct TestConfig {      
    std::string data_path = "./";    
    std::vector<std::string> test_cases = {"standard", "simple", "medium","work"};    
    // std::vector<std::string> test_cases = {"work"};   
    double tolerance = 1e-5;      
};      
  
// 从文件读取二进制数据      
std::vector<float> load_binary_data(const std::string& filename, size_t expected_size) {      
    std::ifstream file(filename, std::ios::binary);      
    if (!file) {      
        throw std::runtime_error("Cannot open file: " + filename);      
    }      
          
    std::vector<float> data(expected_size);      
    file.read(reinterpret_cast<char*>(data.data()), expected_size * sizeof(float));      
          
    if (file.gcount() != expected_size * sizeof(float)) {      
        throw std::runtime_error("File size mismatch: " + filename);      
    }      
          
    return data;      
}      
  
// 解析元数据文件  
TestCase parse_metadata(const std::string& name) {      
    std::ifstream file(name);      
    if (!file) {      
        throw std::runtime_error("Cannot open metadata file for: " + name);      
    }      
          
    TestCase test_case;      
    test_case.name = name;      
        
    int64_t pytorch_shape[4];    
          
    std::string line;      
    while (std::getline(file, line)) {      
        if (line.find("shape: ") == 0) {      
            sscanf(line.c_str(), "shape: (%lld, %lld, %lld, %lld)",       
                   &pytorch_shape[0], &pytorch_shape[1],       
                   &pytorch_shape[2], &pytorch_shape[3]);      
                
            // 转换为GGML维度顺序 (width, height, seq_len, batch)    
            test_case.shape[0] = pytorch_shape[3];  // width    
            test_case.shape[1] = pytorch_shape[2];  // height      
            test_case.shape[2] = pytorch_shape[1];  // seq_len    
            test_case.shape[3] = pytorch_shape[0];  // batch    
        } else if (line.find("eps: ") == 0) {      
            sscanf(line.c_str(), "eps: %f", &test_case.eps);      
        } else if (line.find("elements: ") == 0) {      
            sscanf(line.c_str(), "elements: %zu", &test_case.elements);      
        } else if (line.find("weight_elements: ") == 0) {      
            sscanf(line.c_str(), "weight_elements: %zu", &test_case.weight_elements);      
        } else if (line.find("gate_elements: ") == 0) {      
            sscanf(line.c_str(), "gate_elements: %zu", &test_case.gate_elements);      
        }  
    }      
          
    return test_case;      
}      
  
// 计算精度指标      
struct PrecisionMetrics {      
    double max_abs_error;      
    double mean_abs_error;      
    double rmse;      
    double relative_error;      
    bool passed;      
};      
  
PrecisionMetrics calculate_metrics(const std::vector<float>& ggml_output,       
                                 const std::vector<float>& pytorch_output,      
                                 double tolerance = 1e-5) {      
    PrecisionMetrics metrics = {};      
          
    if (ggml_output.size() != pytorch_output.size()) {      
        throw std::runtime_error("Output size mismatch");      
    }      
          
    double sum_abs_error = 0.0;      
    double sum_squared_error = 0.0;      
    double sum_pytorch_squared = 0.0;      
          
    for (size_t i = 0; i < ggml_output.size(); ++i) {      
        double abs_error = std::abs(ggml_output[i] - pytorch_output[i]);      
        double squared_error = abs_error * abs_error;      
              
        metrics.max_abs_error = std::max(metrics.max_abs_error, abs_error);      
        sum_abs_error += abs_error;      
        sum_squared_error += squared_error;      
        sum_pytorch_squared += pytorch_output[i] * pytorch_output[i];      
    }      
          
    metrics.mean_abs_error = sum_abs_error / ggml_output.size();      
    metrics.rmse = std::sqrt(sum_squared_error / ggml_output.size());      
    metrics.relative_error = std::sqrt(sum_squared_error / sum_pytorch_squared);      
    metrics.passed = metrics.max_abs_error < tolerance;      
          
    return metrics;      
}      
  
// 创建Qwen3NextRMSNormGated操作:RMSNorm(x) * silu(gate) * weight  
ggml_tensor* create_qwen3_next_rms_norm_gated(ggml_context* ctx, ggml_tensor* input,   
                                              ggml_tensor* gate, ggml_tensor* weight, float eps) {    
    // 步骤1: RMS归一化  
    ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);    
      
    // 步骤2: 对门控张量应用SiLU激活函数  
    ggml_tensor* silu_gate = ggml_silu(ctx, gate);  
      
    // 步骤3: RMS归一化结果与SiLU门控相乘  
    ggml_tensor* gated_result = ggml_mul(ctx, rms_norm_result, silu_gate);  
      
    // 步骤4: 与权重张量相乘  
    ggml_tensor* final_result = ggml_mul(ctx, gated_result, weight);  
        
    return final_result;    
} 

// 使用SwiGLU精简的实现  
ggml_tensor* create_qwen3_next_rms_norm_gated_optimized(ggml_context* ctx, ggml_tensor* input,   
                                                        ggml_tensor* gate, ggml_tensor* weight, float eps) {  
    // 步骤1: RMS归一化  
    ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);  
      
    // 步骤2: 使用SwiGLU融合SiLU门控和乘法操作  
    ggml_tensor* swiglu_result = ggml_swiglu_split(ctx, gate, rms_norm_result); 
      
    // 步骤3: 与权重张量相乘  
    ggml_tensor* final_result = ggml_mul(ctx, swiglu_result, weight);  
      
    return final_result;  
}

// 优化的Qwen3NextRMSNormGated实现:先norm再weight再swiglu_split  
ggml_tensor* create_qwen3_next_rms_norm_gated_optimized_order(ggml_context* ctx, ggml_tensor* input,   
                                                        ggml_tensor* gate, ggml_tensor* weight, float eps) {  
    // 步骤1: RMS归一化  
    ggml_tensor* rms_norm_result = ggml_rms_norm(ctx, input, eps);  
      
    // 步骤2: 与权重张量相乘 (可能被Metal后端融合为RMS_NORM_MUL)  
    ggml_tensor* weighted_result = ggml_mul(ctx, rms_norm_result, weight);  
      
    // 步骤3: 使用SwiGLU处理门控 - gate经过SiLU激活后与weighted_result相乘  
    ggml_tensor* final_result = ggml_swiglu_split(ctx, gate, weighted_result);  
      
    return final_result;  
}
  
// 运行单个测试用例      
bool run_test_case(const std::string& test_name, ggml_backend_t backend, const TestConfig& config) {      
    try {      
        // 解析测试用例元数据      
        TestCase test_case = parse_metadata(config.data_path + "meta_" + test_name + ".txt");       
              
        // 加载输入、门控、权重和期望输出      
        auto input_data = load_binary_data(config.data_path + "input_" + test_name + ".bin", test_case.elements);      
        auto gate_data = load_binary_data(config.data_path + "gate_" + test_name + ".bin", test_case.gate_elements);  
        auto weight_data = load_binary_data(config.data_path + "weight_" + test_name + ".bin", test_case.weight_elements);    
        auto expected_output = load_binary_data(config.data_path + "expected_" + test_name + ".bin", test_case.elements);      
       
        // 创建 GGML 上下文      
        ggml_init_params params = {      
            .mem_size = 64 * 1024 * 1024,  // 64MB (增加内存以支持更多张量)  
            .mem_buffer = nullptr,      
            .no_alloc = true,      
        };      
              
        ggml_context* ctx = ggml_init(params);      
        if (!ctx) {      
            throw std::runtime_error("Failed to initialize GGML context");      
        }      
              
        // 构建计算图    
        ggml_tensor* input = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, test_case.shape.data());    
        ggml_tensor* gate = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, test_case.shape.data());  // 门控张量与输入同形状  
            
        // 权重张量形状:权重应该在GGML的第0维(对应PyTorch的最后一维)    
        int64_t weight_shape[4] = {static_cast<int64_t>(test_case.weight_elements), 1, 1, 1};    
        ggml_tensor* weight = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, weight_shape);    
            
        // 使用Qwen3NextRMSNormGated操作    
        // ggml_tensor* output = create_qwen3_next_rms_norm_gated(ctx, input, gate, weight, test_case.eps);  
        // ggml_tensor* output = create_qwen3_next_rms_norm_gated_optimized(ctx, input, gate, weight, test_case.eps);  
        ggml_tensor* output = create_qwen3_next_rms_norm_gated_optimized_order(ctx, input, gate, weight, test_case.eps); 
              
        ggml_cgraph* graph = ggml_new_graph(ctx);      
        ggml_build_forward_expand(graph, output);      
              
        // 分配内存      
        ggml_backend_buffer_t buffer = ggml_backend_alloc_ctx_tensors(ctx, backend);      
        if (!buffer) {      
            throw std::runtime_error("Failed to allocate backend buffer");      
        }      
              
        // 设置输入数据、门控数据和权重数据    
        ggml_backend_tensor_set(input, input_data.data(), 0, input_data.size() * sizeof(float));    
        ggml_backend_tensor_set(gate, gate_data.data(), 0, gate_data.size() * sizeof(float));  
        ggml_backend_tensor_set(weight, weight_data.data(), 0, weight_data.size() * sizeof(float));    
              
        // 执行计算      
        ggml_status status = ggml_backend_graph_compute(backend, graph);      
        if (status != GGML_STATUS_SUCCESS) {      
            throw std::runtime_error("Graph computation failed");      
        }      
           
        // // 👇👇👇 关键修复:立即同步,确保 GPU 开始工作 👇👇👇
        // ggml_backend_synchronize(backend); // 确保提交的命令已开始执行
        // // 👇👇👇 关键修复:GPU 计算刚完成,立即暂停等待你捕获 Frame 👇👇👇
        // printf(">>> Test case '%s' GPU 计算已完成!sleep 10秒,请立即点击 Xcode 📷 Capture GPU Frame 按钮!<<<\n", test_name.c_str());
        // sleep(3); // 给你 5 秒钟点击 Capture 按钮
        // // 👆👆👆 关键修复结束 👆👆👆
        // printf(">>> 程序继续运行... <<<\n");

        // 获取输出结果      
        std::vector<float> ggml_output(test_case.elements);      
        ggml_backend_tensor_get(output, ggml_output.data(), 0, ggml_output.size() * sizeof(float));      
              
        // 计算精度指标      
        PrecisionMetrics metrics = calculate_metrics(ggml_output, expected_output, config.tolerance);        
              
        // 输出结果  
        std::cout << "\n=== Test Case: " << test_name << " ===" << std::endl;      
        std::cout << "GGML Shape: [" << test_case.shape[0] << ", " << test_case.shape[1]       
                  << ", " << test_case.shape[2] << ", " << test_case.shape[3] << "]" << std::endl;      
        std::cout << "EPS: " << test_case.eps << std::endl;    
        std::cout << "Weight Elements: " << test_case.weight_elements << std::endl;    
        std::cout << "Gate Elements: " << test_case.gate_elements << std::endl;  
        std::cout << std::fixed << std::setprecision(8);      
        std::cout << "Max Absolute Error: " << metrics.max_abs_error << std::endl;      
        std::cout << "Mean Absolute Error: " << metrics.mean_abs_error << std::endl;      
        std::cout << "RMSE: " << metrics.rmse << std::endl;      
        std::cout << "Relative Error: " << metrics.relative_error << std::endl;      
        std::cout << "Status: " << (metrics.passed ? "PASS" : "FAIL") << std::endl;      
              
        // 清理资源      
        ggml_backend_buffer_free(buffer);      
        ggml_free(ctx);      
              
        return metrics.passed;      
              
    } catch (const std::exception& e) {      
        std::cerr << "Error in test case " << test_name << ": " << e.what() << std::endl;      
        return false;      
    }      
}      
  
TestConfig parse_args(int argc, char* argv[]) {      
    TestConfig config;      
          
    const char* env_data_path = std::getenv("GGML_TEST_DATA_PATH");    
    if (env_data_path) {    
        config.data_path = env_data_path;    
        if (!config.data_path.empty() && config.data_path.back() != '/') {    
            config.data_path += '/';    
        }    
    }  
  
    for (int i = 1; i < argc; i++) {      
        if (strcmp(argv[i], "--data-path") == 0 && i + 1 < argc) {      
            config.data_path = argv[++i];      
            if (!config.data_path.empty() && config.data_path.back() != '/') {      
                config.data_path += '/';      
            }      
        } else if (strcmp(argv[i], "--tolerance") == 0 && i + 1 < argc) {      
            config.tolerance = std::stod(argv[++i]);      
        } else if (strcmp(argv[i], "--test-case") == 0 && i + 1 < argc) {  
            // 允许指定特定的测试用例  
            config.test_cases.clear();  
            config.test_cases.push_back(argv[++i]);  
        } else if (strcmp(argv[i], "--help") == 0 || strcmp(argv[i], "-h") == 0) {  
            std::cout << "Usage: " << argv[0] << " [options]\n";  
            std::cout << "Options:\n";  
            std::cout << "  --data-path PATH    Path to test data files (default: ./)\n";  
            std::cout << "  --tolerance VALUE   Error tolerance for tests (default: 1e-5)\n";  
            std::cout << "  --test-case NAME    Run specific test case (default: all)\n";  
            std::cout << "  --help, -h          Show this help message\n";  
            exit(0);  
        } else {  
            std::cerr << "Unknown argument: " << argv[i] << std::endl;  
            std::cerr << "Use --help for usage information" << std::endl;  
            exit(1);  
        }  
    }      
          
    return config;      
}

int main(int argc, char* argv[]) {      
    TestConfig config = parse_args(argc, argv);    
        
    // 初始化Metal后端  
    ggml_backend_t backend = ggml_backend_metal_init();      
    if (!backend) {      
        std::cerr << "Failed to initialize Metal backend" << std::endl;      
        return 1;      
    }      
    
    int passed = 0;      
    int total = config.test_cases.size();    
          
    std::cout << "Running GGML Metal vs PyTorch Qwen3NextRMSNormGated Precision Tests" << std::endl;    
    std::cout << "Data Path: " << config.data_path << std::endl;    
    std::cout << "Tolerance: " << config.tolerance << std::endl;    
    std::cout << "Test Cases: " << total << std::endl;  
    std::cout << "Note: PyTorch dimensions are automatically converted to GGML order" << std::endl;    
    std::cout << "================================================" << std::endl;    
          
    // const char* metal_path = getenv("GGML_METAL_PATH");
    // if (metal_path) {
    //     printf(">>> GGML_METAL_PATH = %s\n", metal_path);
    // } else {
    //     printf(">>> GGML_METAL_PATH not set!\n");
    // }


    // 运行所有测试用例  
    for (const auto& test_name : config.test_cases) {  
        if (run_test_case(test_name, backend, config)) {    
            passed++;    
        }    
    }    
          
    // 输出测试总结  
    std::cout << "\n=== Test Summary ===" << std::endl;    
    std::cout << "Passed: " << passed << "/" << total << std::endl;    
    std::cout << "Success Rate: " << (100.0 * passed / total) << "%" << std::endl;     
      
    if (passed == total) {  
        std::cout << "All tests PASSED!" << std::endl;  
    } else {  
        std::cout << "Some tests FAILED!" << std::endl;  
    }  
    

    ggml_backend_free(backend);      
          
    return (passed == total) ? 0 : 1;      
}
previous result 截屏2025-09-27 21 50 56

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

Apple Metal https://en.wikipedia.org/wiki/Metal_(API) ggml changes relating to the ggml tensor library for machine learning testing Everything test related

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant