examples/models/llama/attention.py (4 additions, 4 deletions)
@@ -243,17 +243,17 @@ def forward(
         k = k.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
         v = v.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)

+        if self.use_qk_norm:
+            q = self.q_norm_fn(q)
+            k = self.k_norm_fn(k)
+
         # RoPE relative positional embeddings
         q, k = self.rope.forward(q, k, freqs_cos, freqs_sin)

         q = q.transpose(1, 2)  # (bs, n_local_heads, seqlen, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)

-        if self.use_qk_norm:
-            q = self.q_norm_fn(q)
-            k = self.k_norm_fn(k)
-
         if self.use_kv_cache:
             assert input_pos is not None
             k, v = self.kv_cache.update(input_pos, k, v)
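For context on the hunk above: the QK-norm is now applied to the per-head query and key tensors before rope.forward rather than after the transpose, which matches how Qwen3 normalizes each head prior to the rotary embedding. A minimal sketch of what q_norm_fn / k_norm_fn amount to, assuming they are RMSNorm modules over head_dim (the RMSNorm class and the shapes below are illustrative, not taken from the repo):

import torch


class RMSNorm(torch.nn.Module):
    # Illustrative RMSNorm over the last (head_dim) axis, the usual form of
    # Qwen3-style QK-norm; the example's real norm module may differ in detail.
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * self.weight


q_norm_fn = RMSNorm(128)  # head_dim from 0_6b_config.json
q = torch.randn(1, 8, 16, 128)  # (bs, seqlen, n_heads, head_dim), pre-RoPE layout
q = q_norm_fn(q)  # each head vector is normalized independently of the others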
examples/models/llama/export_llama_lib.py (10 additions, 0 deletions)
@@ -100,6 +100,9 @@
"llama3_2",
"static_llama",
"qwen2_5",
"qwen3-0_6b",
"qwen3-1_7b",
"qwen3-4b",
"phi_4_mini",
"smollm2",
]
@@ -108,6 +111,9 @@
"qwen2_5": "Qwen/Qwen2.5-1.5B",
"phi_4_mini": "microsoft/Phi-4-mini-instruct",
"smollm2": "HuggingFaceTB/SmolLM-135M",
"qwen3-0_6b": "Qwen/Qwen3-0.6B",
"qwen3-1_7b": "Qwen/Qwen3-1.7B",
"qwen3-4b": "Qwen/Qwen3-4B",
}


@@ -544,6 +550,10 @@ def export_llama(args) -> str:
             from executorch.examples.models.qwen2_5 import (  # pyre-ignore[21]
                 convert_weights,
             )
+        elif args.model.startswith("qwen3"):
+            from executorch.examples.models.qwen3 import (  # pyre-ignore[21]
+                convert_weights,
+            )
         elif args.model == "phi_4_mini":
             from executorch.examples.models.phi_4_mini import (  # pyre-ignore[21]
                 convert_weights,
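The additions above register three Qwen3 sizes in the model list, map them to their Hugging Face repos, and route any model name starting with "qwen3" to the Qwen3 weight converter. A hedged end-to-end sketch of what that wiring enables (assumes huggingface_hub is installed; the output path is a placeholder):

from huggingface_hub import snapshot_download

from executorch.examples.models.qwen3 import convert_weights

# Fetch the checkpoint registered above and convert it into the Meta-format
# .pth file that the llama export flow consumes.
checkpoint_dir = snapshot_download("Qwen/Qwen3-0.6B")
convert_weights(checkpoint_dir, "/tmp/qwen3_0_6b_meta.pth")

The same conversion is also available as a standalone script, examples/models/qwen3/convert_weights.py, via its main() entry point shown further down.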
examples/models/qwen3/0_6b_config.json (16 additions, 0 deletions)
@@ -0,0 +1,16 @@
{
"dim": 1024,
"ffn_dim_multiplier": 1,
"hidden_dim": 3072,
"n_heads": 16,
"head_dim": 128,
"n_kv_heads": 8,
"n_layers": 28,
"norm_eps": 1e-06,
"rope_theta": 1000000.0,
"use_scaled_rope": false,
"vocab_size": 151936,
"use_hf_rope": true,
"attention_qkv_bias": false,
"use_qk_norm": true
}
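One detail worth flagging in this config: head_dim is given explicitly as 128 rather than being derived as dim / n_heads (which would be 1024 / 16 = 64), since Qwen3 decouples the per-head width from the model width. A small sanity-check sketch (the relative path assumes the repo root as the working directory):

import json

with open("examples/models/qwen3/0_6b_config.json") as f:
    cfg = json.load(f)

# 16 heads x 128 head_dim = 2048, which is larger than dim = 1024, so the
# attention projections are rectangular (dim -> n_heads * head_dim) rather
# than square.
print(cfg["n_heads"] * cfg["head_dim"], cfg["dim"])  # prints: 2048 1024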
examples/models/qwen3/__init__.py (16 additions, 0 deletions)
@@ -0,0 +1,16 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from executorch.examples.models.llama.model import Llama2Model
from executorch.examples.models.qwen3.convert_weights import convert_weights


class Qwen3Model(Llama2Model):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)


__all__ = [
    "Qwen3Model",
    "convert_weights",
]
examples/models/qwen3/convert_weights.py (115 additions, 0 deletions)
@@ -0,0 +1,115 @@
import argparse

import json
import os
from typing import Dict

import torch
from safetensors.torch import load_file

from torchtune.models.convert_weights import get_mapped_key

# Standard _FROM_META weight mapping of Meta weights to TorchTune + additional QK-norm weight mappings.
_QWEN_3_FROM_META = {
    "tok_embeddings.weight": "model.embed_tokens.weight",
    "norm.weight": "model.norm.weight",
    "layers.{}.attention.wk.weight": "model.layers.{}.self_attn.k_proj.weight",
    "layers.{}.attention.k_norm_fn.weight": "model.layers.{}.self_attn.k_norm.weight",
    "layers.{}.attention.wq.weight": "model.layers.{}.self_attn.q_proj.weight",
    "layers.{}.attention.q_norm_fn.weight": "model.layers.{}.self_attn.q_norm.weight",
    "layers.{}.attention.wv.weight": "model.layers.{}.self_attn.v_proj.weight",
    "layers.{}.attention.wo.weight": "model.layers.{}.self_attn.o_proj.weight",
    "layers.{}.attention_norm.weight": "model.layers.{}.input_layernorm.weight",
    "layers.{}.ffn_norm.weight": "model.layers.{}.post_attention_layernorm.weight",
    # Note: gate_proj and up_proj are reversed, usually w1 is the up_proj,
    # w2 is the gate_proj, and activation is applied on the up_proj, but since
    # Qwen3 applies activation on the gate_proj, we just swap the gate_proj
    # and up_proj in the checkpoint itself as a hack.
    "layers.{}.feed_forward.w1.weight": "model.layers.{}.mlp.gate_proj.weight",
    "layers.{}.feed_forward.w2.weight": "model.layers.{}.mlp.down_proj.weight",
    "layers.{}.feed_forward.w3.weight": "model.layers.{}.mlp.up_proj.weight",
}
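
# Illustration (not part of the converter): torchtune's get_mapped_key templates
# the layer index, so inverting the mapping above lets an HF checkpoint key be
# rewritten to its Meta-style name, e.g.
#
#     inverted = {v: k for k, v in _QWEN_3_FROM_META.items()}
#     get_mapped_key("model.layers.3.self_attn.q_proj.weight", inverted)
#     # -> "layers.3.attention.wq.weight"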


def qwen_3_tune_to_meta(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """
    Convert a state dict from torchtune's format to Meta's format. This function
    doesn't handle any sharding or splitting of state dicts. It follows the
    state_dict IN -> state_dict OUT pattern.

    Args:
        state_dict (Dict[str, torch.Tensor]): State dict in torchtune's format.

    Returns:
        Dict[str, torch.Tensor]: State dict in Meta's format.
    """
    converted_state_dict = {}
    inverted_mapping_dict = {v: k for k, v in _QWEN_3_FROM_META.items()}

    for key, value in state_dict.items():
        # Tied embeddings for 0.6b and 4b models.
        if key == "lm_head.weight":
            continue
        new_key = get_mapped_key(key, inverted_mapping_dict)
        converted_state_dict[new_key] = value

    converted_state_dict["output.weight"] = converted_state_dict[
        "tok_embeddings.weight"
    ]

    return converted_state_dict
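
# Note (illustration, not part of the converter): because "lm_head.weight" is
# skipped and "output.weight" is assigned from the embedding table, both keys in
# the converted dict point at the same tensor, matching the tied-embedding
# Qwen3 checkpoints.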


def load_checkpoint(input_dir: str) -> Dict:
    index_path = os.path.join(input_dir, "model.safetensors.index.json")
    if os.path.exists(index_path):
        # Sharded checkpoint.
        with open(index_path, "r") as f:
            index = json.load(f)
        weight_map = index["weight_map"]
        checkpoint_shards = sorted(set(weight_map.values()))

        # Load all the shards into memory
        shard_to_weights = {}
        for shard in checkpoint_shards:
            shard_to_weights[shard] = load_file(os.path.join(input_dir, shard))

        # Merge tensors into consolidated state dict.
        merged_state_dict = {}
        for weight_name, shard in weight_map.items():
            tensor = shard_to_weights[shard][weight_name]
            merged_state_dict[weight_name] = tensor
        return merged_state_dict
    else:
        # Single checkpoint.
        state_dict = load_file(os.path.join(input_dir, "model.safetensors"))
        return state_dict
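
# For reference, the sharded branch above walks a model.safetensors.index.json
# whose "weight_map" ties each tensor name to the shard file that stores it; a
# trimmed, hypothetical example of that layout:
#
#     {
#       "metadata": {"total_size": ...},
#       "weight_map": {
#         "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
#         "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
#         "model.norm.weight": "model-00002-of-00002.safetensors"
#       }
#     }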


def convert_weights(input_dir: str, output_file: str) -> None:
    print("Loading checkpoint...")
    sd = load_checkpoint(input_dir)
    print("Converting checkpoint...")
    sd = qwen_3_tune_to_meta(sd)
    print("Saving checkpoint...")
    torch.save(sd, output_file)
    print("Done.")


def main():
    parser = argparse.ArgumentParser(
        description="Convert Qwen3 weights to Meta format."
    )
    parser.add_argument(
        "input_dir",
        type=str,
        help="Path to directory containing checkpoint files",
    )
    parser.add_argument("output", type=str, help="Path to the output checkpoint")

    args = parser.parse_args()
    convert_weights(args.input_dir, args.output)


if __name__ == "__main__":
    main()