Commit b92970c
Add OpenAI's gpt-oss to ONNX Runtime GenAI (microsoft#1678)
### Description

This PR adds [OpenAI's gpt-oss](https://huggingface.co/openai/gpt-oss-20b) to ONNX Runtime GenAI.

### Usage

You can download and run optimized and quantized ONNX versions of `gpt-oss-20b` from Hugging Face [here](https://huggingface.co/onnxruntime/gpt-oss-20b-onnx). If you want to build your own optimized and quantized ONNX variants, here is how you can do that.

INT4-only model:

```bash
# INT4 CPU + INT4 QMoE op (k-quant mixed)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cpu -p int4 -e cpu -c ./cache_dir --extra_options int4_algo_config=k_quant_mixed

# INT4 CPU + INT4 QMoE op (default)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cpu -p int4 -e cpu -c ./cache_dir --extra_options int4_op_types_to_quantize=MatMul/Gather

# INT4 CUDA + INT4 QMoE op (k-quant mixed)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cuda -p int4 -e cuda -c ./cache_dir --extra_options int4_algo_config=k_quant_mixed

# INT4 CUDA + INT4 QMoE op (default)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cuda -p int4 -e cuda -c ./cache_dir --extra_options int4_op_types_to_quantize=MatMul/Gather

# INT4 for all other execution providers (default)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_ep -p int4 -e execution_provider -c ./cache_dir --extra_options int4_op_types_to_quantize=MatMul/Gather
```

INT4 + INT8 model:

```bash
# INT4 CPU + INT8 QMoE op (k-quant mixed)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cpu -p int4 -e cpu -c ./cache_dir --extra_options int4_algo_config=k_quant_mixed use_8bits_moe=true

# INT4 CPU + INT8 QMoE op (default)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cpu -p int4 -e cpu -c ./cache_dir --extra_options int4_op_types_to_quantize=MatMul/Gather use_8bits_moe=true

# INT4 CUDA + INT8 QMoE op (k-quant mixed)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cuda -p int4 -e cuda -c ./cache_dir --extra_options int4_algo_config=k_quant_mixed use_8bits_moe=true

# INT4 CUDA + INT8 QMoE op (default)
$ python builder.py -m openai/gpt-oss-20b -o ./gpt_oss_int4_cuda -p int4 -e cuda -c ./cache_dir --extra_options int4_op_types_to_quantize=MatMul/Gather use_8bits_moe=true
```

### K-Quant Mixed

K-quant mixed is an algorithm that quantizes the weights of `MatMul` ops only. Weights in `Gather` ops (e.g. the embedding, the weights in the unfused MoE subgraph, etc.) are not quantized by it. To quantize both `MatMul` and `Gather`, remove `int4_algo_config` from the `extra_options` to fall back to the default RTN algorithm (which can be applied to both ops) and add `int4_op_types_to_quantize=MatMul/Gather` to the `extra_options`.

### Quantization with MoE Subgraph

Currently, the `MoE` and `QMoE` ops are only supported by the CPU EP and the CUDA EP. The quantization performed on the `QMoE` op is controlled by `use_8bits_moe` in the `extra_options`: if set to `true`, the `QMoE` op is quantized to INT8; if omitted or set to `false`, it is quantized to INT4.

When the `MoE` or `QMoE` ops are not in the ONNX model (e.g. for other EPs), the unfused MoE subgraph is created and the MoE weights are stored in `Gather` nodes.
Thus, for those other EPs it is important to:

- a) provide `int4_op_types_to_quantize=MatMul/Gather` so that all weights in the model are quantized to INT4 (otherwise the resulting ONNX model will be very large since it contains unquantized weights), and
- b) use the default RTN algorithm to quantize the `Gather` ops (since the other algorithms do not quantize that op).

Following the flow that was established when the Phi-3.5 MoE model was onboarded into the model builder, [TensorRT-LLM](https://github.com/nvidia/TensorRT-LLM) is used to perform channel-wise quantization on the `QMoE` op. Please note that pre-built Python wheels for TensorRT-LLM are only [available](https://pypi.org/project/tensorrt-llm/) for Python 3.10 and Python 3.12.

Please run `python builder.py --help` for more details.

### Run Inference

To run inference, you can follow the `model-chat.py` example.

```bash
$ python model-chat.py -m /path/to/folder/containing/gpt_oss/onnx/model/ -e execution_provider --top_k 50 --do_sample
```
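If you prefer to call the `onnxruntime-genai` Python API directly instead of going through `model-chat.py`, a minimal generation loop might look like the sketch below. This is only an illustration, not part of this PR: the model folder and prompt are placeholders, the search options mirror the `--top_k 50 --do_sample` flags above, and it skips the chat-template handling that `model-chat.py` performs.

```python
# Minimal sketch (not part of this PR); the model path and prompt are placeholders.
import onnxruntime_genai as og

model = og.Model("./gpt_oss_int4_cuda")        # folder produced by builder.py
tokenizer = og.Tokenizer(model)
stream = tokenizer.create_stream()

params = og.GeneratorParams(model)
params.set_search_options(do_sample=True, top_k=50, max_length=256)

generator = og.Generator(model, params)
generator.append_tokens(tokenizer.encode("What is ONNX Runtime?"))

# Generate one token at a time and stream the decoded text.
while not generator.is_done():
    generator.generate_next_token()
    print(stream.decode(generator.get_next_tokens()[0]), end="", flush=True)
print()
```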
### Motivation and Context

"As part of today’s release, Microsoft is also bringing GPU-optimized versions of the gpt-oss-20b model to Windows devices. Powered by ONNX Runtime, these models support local inference and are available through Foundry Local and the AI Toolkit for VS Code, making it easier for Windows developers to build with open models." -[OpenAI](https://openai.com/index/introducing-gpt-oss/)

"Today, **gpt‑oss-120b** and **gpt‑oss-20b** are available on Azure AI Foundry. gpt‑oss-20b is also available on Windows AI Foundry and will be coming soon on MacOS via Foundry Local." -[Microsoft](https://aka.ms/OAIOSSfoundryblog)

### QMoE Op Spec

Here are the tensor shapes for the `qweight` and `scales` tensors after stacking on the batch dimension and [before serializing to disk](https://github.com/microsoft/onnxruntime-genai/blob/1b8b6fd6d4143d2dcd577db8a84daf03e9cd1f1f/src/python/py/models/builder.py#L4152).

For INT4 (i.e. `expert_weight_bits = 4`):

<img width="606" height="276" alt="image" src="https://github.com/user-attachments/assets/05d37bc4-15a3-4ded-a033-aae7317d1571" />

For INT8 (i.e. `expert_weight_bits = 8`):

<img width="608" height="282" alt="image" src="https://github.com/user-attachments/assets/02621373-d251-4843-bb61-f5f3cf24f945" />

The [model's configuration](https://huggingface.co/openai/gpt-oss-20b/blob/main/config.json) says that:

- `num_experts = 32`
- `inter_size = 2880`
- `hidden_size = 2880`

#### Old Op Spec

The [old op spec](https://github.com/microsoft/onnxruntime/blob/3b10e44769db8320e033a6102c33768f1bc69456/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L1438-L1443) says the following.

```
.Input(2, "fc1_experts_weights",
       "3D input tensor with shape (num_experts, hidden_size, inter_size) "
       "or (num_experts, hidden_size, inter_size / 2). For swiglu, shape can be (num_experts, hidden_size, 2 * inter_size) or (num_experts, hidden_size, inter_size).",
       "T1")
.Input(3, "fc1_scales",
       "2D input tensor with shape (num_experts, inter_size), or (num_experts, 2 * inter_size) for swiglu",
       "T2")
.Input(5, "fc2_experts_weights",
       "3D input tensor with shape (num_experts, inter_size, hidden_size) "
       "or (num_experts, inter_size, hidden_size / 2)",
       "T1")
.Input(6, "fc2_scales",
       "2D input tensor with shape (num_experts, hidden_size)",
       "T2")
```

FC1 weights:

- INT4: `(num_experts, hidden_size, inter_size)` = (32, 2880, 2880)
- INT8: `(num_experts, hidden_size, 2 * inter_size)` = (32, 2880, 2 * 2880) = (32, 2880, 5760)
- This matches the expected shape.

FC1 scales:

- `(num_experts, inter_size)`
- INT4: (32, 2880)
- INT8: (32, 2880)
- This matches the expected shape.

FC2 weights:

- INT4: `(num_experts, inter_size, hidden_size / 2)` = (32, 2880, 2880 / 2) = (32, 2880, 1440)
- INT8: `(num_experts, inter_size, hidden_size)` = (32, 2880, 2880)
- This matches the expected shape.

FC2 scales:

- `(num_experts, hidden_size)`
- INT4: (32, 2880)
- INT8: (32, 2880)
- This matches the expected shape.

#### New Op Spec

The [new op spec](https://github.com/microsoft/onnxruntime/blob/a1a5baec454445864ad0880ec18d8a5d1a83b48c/onnxruntime/core/graph/contrib_ops/contrib_defs.cc#L1500-L1509) says the following.

```
.Input(2, "fc1_experts_weights",
       "3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / pack_size), "
       "The fusion_size is 2 for fused swiglu, or 1 otherwise. The pack_size is 8 / expert_weight_bits.",
       "T1")
.Input(3, "fc1_scales",
       "2D tensor with shape (num_experts, fusion_size * inter_size), or "
       "3D tensor with shape (num_experts, fusion_size * inter_size, hidden_size / block_size) when block_size is provided.",
       "T2")
.Input(5, "fc2_experts_weights",
       "3D tensor with shape (num_experts, hidden_size, inter_size / pack_size)",
       "T1")
.Input(6, "fc2_scales",
       "2D tensor with shape (num_experts, hidden_size), or "
       "3D tensor with shape (num_experts, hidden_size, inter_size / block_size) when block_size is provided.",
       "T2")
```

FC1 weights:

- `(num_experts, fusion_size * inter_size, hidden_size / pack_size)`
- INT4: (32, 2 * 2880, 2880 / (8 / 4)) = (32, 5760, 1440)
- INT8: (32, 2 * 2880, 2880 / (8 / 8)) = (32, 5760, 2880)
- This can be obtained by performing a `torch.view()` operation to modify the existing tensor's shape without changing the underlying layout (e.g. `tensor.view(32, 5760, -1)`). After the view operation, this matches the expected shape. This means that `fusion_size = 2`.

FC1 scales:

- `(num_experts, fusion_size * inter_size)`
- INT4: (32, 2 * 2880) = (32, 5760)
- INT8: (32, 2 * 2880) = (32, 5760)
- This matches the expected shape.

FC2 weights:

- `(num_experts, hidden_size, inter_size / pack_size)`
- INT4: (32, 2880, 2880 / 2) = (32, 2880, 1440)
- INT8: (32, 2880, 2880 / 1) = (32, 2880, 2880)
- This matches the expected shape.

FC2 scales:

- `(num_experts, hidden_size)`
- INT4: (32, 2880)
- INT8: (32, 2880)
- This matches the expected shape.
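As a cross-check of the arithmetic above, here is a small sketch (not part of the PR) that re-derives the new-op-spec shapes from the configuration values and demonstrates the `torch.view()` reinterpretation for the FC1 weights. The zero-filled tensor is only a placeholder with the right element count, standing in for the real packed weights.

```python
# Sketch: re-derive the new QMoE op-spec shapes for gpt-oss-20b (not from the PR).
import torch

num_experts, inter_size, hidden_size = 32, 2880, 2880
fusion_size = 2  # fused swiglu

for expert_weight_bits in (4, 8):
    pack_size = 8 // expert_weight_bits  # 2 for INT4, 1 for INT8
    print(f"expert_weight_bits={expert_weight_bits}")
    print("  fc1_experts_weights:", (num_experts, fusion_size * inter_size, hidden_size // pack_size))
    print("  fc1_scales:         ", (num_experts, fusion_size * inter_size))
    print("  fc2_experts_weights:", (num_experts, hidden_size, inter_size // pack_size))
    print("  fc2_scales:         ", (num_experts, hidden_size))

# view() reinterprets the stacked FC1 weights as (32, 5760, -1) without changing the
# underlying memory layout; the placeholder shape here only matters for element count.
packed_fc1 = torch.zeros(num_experts, hidden_size, inter_size, dtype=torch.uint8)  # INT4 element count
assert packed_fc1.view(num_experts, fusion_size * inter_size, -1).shape == (32, 5760, 1440)
```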
1 parent 625815e commit b92970c

4 files changed: +727 -236 lines changed


README.md

Lines changed: 1 addition & 1 deletion
```diff
@@ -16,7 +16,7 @@ See documentation at https://onnxruntime.ai/docs/genai.
 
 |Support matrix|Supported now|Under development|On the roadmap|
 | -------------- | ------------- | ----------------- | -------------- |
-| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Gemma <br/> Granite <br/> Llama * <br/> Mistral + <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen <br/> SmolLM3 | Whisper | Stable diffusion |
+| Model architectures | AMD OLMo <br/> ChatGLM <br/> DeepSeek <br/> ERNIE 4.5 <br/> Gemma <br/> gpt-oss <br/> Granite <br/> Llama * <br/> Mistral + <br/> Nemotron <br/> Phi (language + vision) <br/> Qwen <br/> SmolLM3 | Whisper | Stable diffusion |
 |API| Python <br/>C# <br/>C/C++ <br/> Java ^ |Objective-C||
 |Platform| Linux <br/> Windows <br/>Mac ^ <br/>Android ^ ||iOS |||
 |Architecture|x86 <br/> x64 <br/> Arm64 ~ ||||
```

src/models/model_type.h

Lines changed: 1 addition & 1 deletion
```diff
@@ -12,7 +12,7 @@ namespace Generators {
 struct ModelType {
   inline static bool IsLLM(const std::string& model_type) {
     // Large-language model (LLM)
-    static constexpr std::array<std::string_view, 19> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
+    static constexpr std::array<std::string_view, 20> LLM = {"chatglm", "decoder", "ernie4_5", "gemma", "gemma2", "gemma3_text", "gpt2", "gptoss", "granite", "llama", "mistral", "nemotron", "olmo", "phi", "phimoe", "phi3", "phi3small", "qwen2", "qwen3", "smollm3"};
     return std::find(LLM.begin(), LLM.end(), model_type) != LLM.end();
   }
```

src/python/py/models/README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -36,6 +36,7 @@ The tool currently supports the following model architectures.
 - DeepSeek
 - ERNIE 4.5
 - Gemma
+- gpt-oss
 - Granite
 - Llama
 - Mistral
```
