Add Qwen3 0.6B, 1.7B, and 4B #10539 (Merged)

Commits (9):
- 803ff1d Add Qwen3 0.6B (jackzhxng)
- 507dbc0 1.7B and 4B (jackzhxng)
- da0d159 1.7b and 4b configs (jackzhxng)
- 5a02439 readme (jackzhxng)
- 17989cc Update top level readme (jackzhxng)
- f9b733f Merge branch 'main' into jz/add-qwen3 (jackzhxng)
- fa7ff5d Update readme (jackzhxng)
- c5dba06 qk norm before rope arg (jackzhxng)
- 626a0f0 Merge branch 'main' into jz/add-qwen3 (jackzhxng)
Files changed
examples/models/qwen3/0_6b_config.json (new file):

@@ -0,0 +1,17 @@
{
  "dim": 1024,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 3072,
  "n_heads": 16,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 28,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}
examples/models/qwen3/1_7b_config.json (new file):

@@ -0,0 +1,17 @@
{
  "dim": 2048,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 6144,
  "n_heads": 16,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 28,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}
examples/models/qwen3/4b_config.json (new file):

@@ -0,0 +1,17 @@
{
  "dim": 2560,
  "ffn_dim_multiplier": 1,
  "hidden_dim": 9728,
  "n_heads": 32,
  "head_dim": 128,
  "n_kv_heads": 8,
  "n_layers": 36,
  "norm_eps": 1e-06,
  "rope_theta": 1000000.0,
  "use_scaled_rope": false,
  "vocab_size": 151936,
  "use_hf_rope": true,
  "attention_qkv_bias": false,
  "use_qk_norm": true,
  "qk_norm_before_rope": true
}
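As a sanity check on the size names, these configs imply parameter counts close to 0.6B, 1.7B, and 4B. The sketch below is not part of the diff; it assumes a standard dense Llama-style decoder (GQA attention, SwiGLU MLP, RMSNorm) with tied input/output embeddings, which matches how the converter in this PR treats the checkpoints.

```python
import json

def approx_params(path: str) -> float:
    """Rough parameter count for a dense Llama-style decoder with tied embeddings."""
    cfg = json.load(open(path))
    d, h = cfg["dim"], cfg["hidden_dim"]
    nh, nkv, hd = cfg["n_heads"], cfg["n_kv_heads"], cfg["head_dim"]
    attn = d * nh * hd + 2 * d * nkv * hd + nh * hd * d  # wq, wk, wv, wo
    mlp = 3 * d * h                                      # gate, up, down projections
    norms = 2 * hd + 2 * d                               # q/k norms + the two per-layer RMSNorms
    # Tied output head, so the embedding matrix is counted once; the final norm adds d.
    total = cfg["vocab_size"] * d + cfg["n_layers"] * (attn + mlp + norms) + d
    return total / 1e9

for name in ("0_6b", "1_7b", "4b"):
    print(name, round(approx_params(f"examples/models/qwen3/{name}_config.json"), 2))
# -> roughly 0.60, 1.72, and 4.02 (billions of parameters)
```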
examples/models/qwen3/README.md (new file):

@@ -0,0 +1,85 @@
## Summary
Qwen 3 is the latest iteration of the Qwen series of large language models (LLMs) developed by Alibaba. The edge-sized Qwen3 variants (0.6B, 1.7B, and 4B) are currently supported.

## Instructions

Qwen 3 uses the same example code as our optimized Llama model; only the checkpoint, model params, and tokenizer differ. Please see the [Llama README page](../llama/README.md) for details.

All commands for exporting and running Llama on various backends also apply to Qwen 3 after swapping in the following args:
```
--model [qwen3-0_6b,qwen3-1_7b,qwen3-4b]
--params [examples/models/qwen3/0_6b_config.json,examples/models/qwen3/1_7b_config.json,examples/models/qwen3/4b_config.json]
```

### Example export
Here is a basic example for exporting Qwen 3; please refer to the Llama README's [Step 2: Prepare model](../llama/README.md#step-2-prepare-model) for more advanced usage.

Export 0.6B to XNNPACK, quantized with 8da4w (8-bit dynamic activations, 4-bit weights):
```
python -m examples.models.llama.export_llama \
--model qwen3-0_6b \
--params examples/models/qwen3/0_6b_config.json \
-kv \
--use_sdpa_with_kv_cache \
-d fp32 \
-X \
--xnnpack-extended-ops \
-qmode 8da4w \
--output_name="qwen3-0_6b.pte" \
--verbose
```

Export 1.7B to XNNPACK, quantized with 8da4w:
```
python -m examples.models.llama.export_llama \
--model qwen3-1_7b \
--params examples/models/qwen3/1_7b_config.json \
-kv \
--use_sdpa_with_kv_cache \
-d fp32 \
-X \
--xnnpack-extended-ops \
-qmode 8da4w \
--output_name="qwen3-1_7b.pte" \
--verbose
```

Export 4B to XNNPACK, quantized with 8da4w:
```
python -m examples.models.llama.export_llama \
--model qwen3-4b \
--params examples/models/qwen3/4b_config.json \
-kv \
--use_sdpa_with_kv_cache \
-d fp32 \
-X \
--xnnpack-extended-ops \
-qmode 8da4w \
--output_name="qwen3-4b.pte" \
--verbose
```

### Example run
With ExecuTorch pybindings:
```
python -m examples.models.llama.runner.native \
--model qwen3-0_6b \
--pte qwen3-0_6b.pte \
--tokenizer ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
--tokenizer_config ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer_config.json \
--prompt "Who is the president of the US?" \
--params examples/models/qwen3/0_6b_config.json \
--max_len 128 \
-kv \
--temperature 0.6
```

With ExecuTorch's sample C++ runner (see the Llama README's [Step 3: Run on your computer to validate](../llama/README.md#step-3-run-on-your-computer-to-validate) to build the runner):
```
cmake-out/examples/models/llama/llama_main \
--model_path qwen3-0_6b.pte \
--tokenizer_path ~/.cache/huggingface/hub/models--Qwen--Qwen3-0.6B/snapshots/a9c98e602b9d36d2a2f7ba1eb0f5f31e4e8e5143/tokenizer.json \
--prompt="Who is the president of the US?"
```

To run the model on an example iOS or Android app, see the Llama README's [Step 5: Build Mobile apps](../llama/README.md#step-5-build-mobile-apps) section.
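The tokenizer paths in the run examples above point into the local Hugging Face cache, and the snapshot hash will differ per download. One way to populate that cache (an assumption about tooling, not something the diff prescribes; `huggingface-cli` ships with the `huggingface_hub` package) is:

```
huggingface-cli download Qwen/Qwen3-0.6B tokenizer.json tokenizer_config.json
```

The command reports where the files were cached, and those locations can be passed to `--tokenizer` and `--tokenizer_config`.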
examples/models/qwen3/__init__.py (new file):

@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.qwen3.convert_weights import convert_weights


class Qwen3Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "Qwen3Model",
    "convert_weights",
]
examples/models/qwen3/convert_weights.py (new file):

@@ -0,0 +1,115 @@
import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional bias weight mappings.
_QWEN_3_FROM_META = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_norm.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    # Note: gate_proj and up_proj are reversed: usually w1 is the up_proj,
    # w2 is the gate_proj, and activation is applied on the up_proj, but since
    # Qwen3 applies activation on the gate_proj, we just swap the gate_proj
    # and up_proj in the checkpoint itself as a hack.
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
}
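# For example, with the inverted mapping built in qwen_3_tune_to_meta below, torchtune's
# get_mapped_key templates the layer index: the Hugging Face key
# "model.layers.0.self_attn.q_norm.weight" maps to "layers.0.attention.q_norm_fn.weight".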

def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

    for key, value in state_dict.items():
        # Tied embeddings for 0.6b and 4b models.
        if key == "lm_head.weight":
            continue
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict
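# Example of the tied-embedding handling above (toy shapes, for illustration):
#   sd = {
#       "model.embed_tokens.weight": torch.zeros(8, 4),
#       "model.norm.weight": torch.ones(4),
#       "lm_head.weight": torch.zeros(8, 4),
#   }
#   out = qwen_3_tune_to_meta(sd)
#   sorted(out)  # ["norm.weight", "output.weight", "tok_embeddings.weight"]
#   # "lm_head.weight" is skipped and "output.weight" aliases "tok_embeddings.weight".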

def load_checkpoint(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory.
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
        return state_dict


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = qwen_3_tune_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert Qwen3 weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()
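For reference, a minimal sketch of how this converter would be driven end to end, assuming it is invoked from the repo root under the `executorch.examples.models.qwen3` package path used in `__init__.py` above (the local directory and output filename are placeholders chosen here for illustration):

```
# Fetch the Hugging Face weights locally, then convert them to a Meta-style checkpoint.
huggingface-cli download Qwen/Qwen3-0.6B --local-dir Qwen3-0.6B
python -m executorch.examples.models.qwen3.convert_weights Qwen3-0.6B qwen3_0_6b_checkpoint.pth
```

The resulting `.pth` can then be supplied to the export commands in the README via `--checkpoint`, alongside the matching `--params` config.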