
Commit 1602e96

Qwen3 Dense (#2044)
Authored by ysjprojects, shijie.yu, and pre-commit-ci[bot]
Co-authored-by: shijie.yu <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent f021d88 commit 1602e96

File tree

8 files changed (+339, -2 lines)

README.md

Lines changed: 1 addition & 0 deletions
```diff
@@ -146,6 +146,7 @@ Every model is written from scratch to maximize performance and remove layers of
 | Qwen2.5 Math | 1.5B, 7B, 72B | Alibaba Group | [An, Yang et al. 2024](https://arxiv.org/abs/2409.12122) |
 | QwQ | 32B | Alibaba Group | [Qwen Team 2025](https://qwenlm.github.io/blog/qwq-32b/) |
 | QwQ-Preview | 32B | Alibaba Group | [Qwen Team 2024](https://qwenlm.github.io/blog/qwq-32b-preview/) |
+| Qwen3 | 0.6B, 1.7B, 4B, 8B, 14B, 32B | Alibaba Group | [Qwen Team 2025](https://arxiv.org/abs/2505.09388/) |
 | R1 Distill Llama | 8B, 70B | DeepSeek AI | [DeepSeek AI 2025](https://github.com/deepseek-ai/DeepSeek-R1/blob/main/DeepSeek_R1.pdf) |
 | SmolLM2 | 135M, 360M, 1.7B | Hugging Face | [Hugging Face 2024](https://github.com/huggingface/smollm) |
 | Salamandra | 2B, 7B | Barcelona Supercomputing Centre | [BSC-LTC 2024](https://github.com/BSC-LTC/salamandra) |
```

litgpt/config.py

Lines changed: 141 additions & 0 deletions
```diff
@@ -2460,6 +2460,147 @@ def norm_class(self) -> Type:
 
 configs.extend(qwq)
 
+qwen_3 = [
+    # https://huggingface.co/Qwen/Qwen3-0.6B/blob/main/config.json
+    dict(
+        name="Qwen3-0.6B{}",
+        hf_config=dict(org="Qwen", name="Qwen3-0.6B{}"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=28,
+        n_head=16,
+        n_embd=1024,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=3072,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        head_size=128,
+        norm_qk=True,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-1.7B/blob/main/config.json
+    dict(
+        name="Qwen3-1.7B{}",
+        hf_config=dict(org="Qwen", name="Qwen3-1.7B{}"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=28,
+        n_head=16,
+        n_embd=2048,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=6144,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-4B/blob/main/config.json
+    dict(
+        name="Qwen3-4B{}",
+        hf_config=dict(org="Qwen", name="Qwen3-4B{}"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=36,
+        n_head=32,
+        n_embd=2560,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=9728,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        head_size=128,
+        norm_qk=True,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-8B/blob/main/config.json
+    dict(
+        name="Qwen3-8B{}",
+        hf_config=dict(org="Qwen", name="Qwen3-8B{}"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=36,
+        n_head=32,
+        n_embd=4096,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=12288,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+    ),
+    # https://huggingface.co/Qwen/Qwen3-14B/blob/main/config.json
+    dict(
+        name="Qwen3-14B{}",
+        hf_config=dict(org="Qwen", name="Qwen3-14B{}"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=40,
+        n_head=40,
+        n_embd=5120,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=17408,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        norm_qk=True,
+    ),
+]
+for c in qwen_3:
+    for kind in ("", "-Base"):
+        copy = deepcopy(c)
+        copy["name"] = c["name"].format(kind)
+        copy["hf_config"]["name"] = c["hf_config"]["name"].format(kind)
+        configs.append(copy)
+qwen_3_32b = [
+    # https://huggingface.co/Qwen/Qwen3-32B/blob/main/config.json
+    dict(
+        name="Qwen3-32B",
+        hf_config=dict(org="Qwen", name="Qwen3-32B"),
+        block_size=40960,
+        vocab_size=151643,
+        padded_vocab_size=151936,
+        n_layer=64,
+        n_head=64,
+        n_embd=5120,
+        n_query_groups=8,
+        rotary_percentage=1.0,
+        parallel_residual=False,
+        bias=False,
+        norm_class_name="RMSNorm",
+        mlp_class_name="LLaMAMLP",
+        intermediate_size=25600,
+        norm_eps=1e-6,
+        rope_base=1000000,
+        head_size=128,
+        norm_qk=True,
+    ),
+]
+configs.extend(qwen_3_32b)
+
 
 #############
 # Salamandra
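```

With these entries registered, a Qwen3 preset should be resolvable by name like any other config in this file. A minimal sketch, assuming a litgpt install that includes this commit (`Config.from_name` is the existing preset lookup):

```python
# Minimal sketch, assuming litgpt with this commit installed.
from litgpt.config import Config

cfg = Config.from_name("Qwen3-0.6B")        # instruct variant produced by the expansion loop above
base = Config.from_name("Qwen3-0.6B-Base")  # "-Base" variant produced by the same loop

print(cfg.n_layer, cfg.n_head, cfg.n_embd)  # 28 16 1024, matching the dict above
print(cfg.norm_qk)                          # True: Qwen3 applies RMSNorm to the Q and K projections
```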

litgpt/prompts.py

Lines changed: 9 additions & 1 deletion
```diff
@@ -345,7 +345,7 @@ def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str)
 
 
 class ChatML(PromptStyle):
-    def __init__(self, system_message: str):
+    def __init__(self, system_message: Optional[str] = None):
         self.system_message = system_message
 
     def apply(self, prompt: str, *, sys_prompt: Optional[str] = None, **kwargs: str) -> str:
@@ -372,6 +372,11 @@ def __init__(self):
         )
 
 
+class Qwen3(ChatML):
+    def __init__(self):
+        super().__init__()
+
+
 class SmolLM2(ChatML):
     def __init__(self):
         super().__init__("You are a helpful AI assistant named SmolLM, trained by Hugging Face")
@@ -411,6 +416,7 @@ def __init__(self):
     "qwen2.5": Qwen2_5,
     "qwen2.5-math": Qwen2_5_Math,
     "qwq": QwQ,
+    "qwen3": Qwen3,
     "smollm2": SmolLM2,
     "salamandra": Salamandra,
 }
@@ -463,6 +469,8 @@ def model_name_to_prompt_style(model_name: str) -> PromptStyle:
         return Qwen2_5()
     if re.search(r"QwQ-.*", model_name):
         return QwQ()
+    if re.search(r"Qwen3-.*", model_name):
+        return Qwen3()
     if re.search(r"SmolLM2.*-Instruct", model_name):
         return SmolLM2()
     if re.search(r"salamandra-.*-instruct", model_name):

litgpt/scripts/convert_hf_checkpoint.py

Lines changed: 73 additions & 0 deletions
```diff
@@ -533,6 +533,75 @@ def copy_weights_qwen_2_5(
             pbar.update(progress_per_file)
 
 
+def copy_weights_qwen_3(
+    config: Config,
+    qkv_weights: Dict[int, List[Optional[NotYetLoadedTensor]]],
+    state_dict: Dict[str, torch.Tensor],
+    hf_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    saver: Optional[incremental_save] = None,
+    dtype: Optional[torch.dtype] = None,
+    pbar: Optional[tqdm] = None,
+    progress_per_file: Optional[float] = None,
+    debug_mode: Optional[bool] = False,
+) -> None:
+    weight_map = {
+        "model.embed_tokens.weight": "transformer.wte.weight",
+        "model.layers.{}.input_layernorm.weight": "transformer.h.{}.norm_1.weight",
+        "model.layers.{}.self_attn.q_proj.weight": None,
+        "model.layers.{}.self_attn.k_proj.weight": None,
+        "model.layers.{}.self_attn.v_proj.weight": None,
+        "model.layers.{}.self_attn.q_norm.weight": "transformer.h.{}.attn.norm_q.weight",
+        "model.layers.{}.self_attn.k_norm.weight": "transformer.h.{}.attn.norm_k.weight",
+        "model.layers.{}.self_attn.o_proj.weight": "transformer.h.{}.attn.proj.weight",
+        "model.layers.{}.post_attention_layernorm.weight": "transformer.h.{}.norm_2.weight",
+        "model.layers.{}.mlp.gate_proj.weight": "transformer.h.{}.mlp.fc_1.weight",
+        "model.layers.{}.mlp.up_proj.weight": "transformer.h.{}.mlp.fc_2.weight",
+        "model.layers.{}.mlp.down_proj.weight": "transformer.h.{}.mlp.proj.weight",
+        "model.norm.weight": "transformer.ln_f.weight",
+        "lm_head.weight": "lm_head.weight",
+    }
+
+    if progress_per_file is not None:
+        progress_per_file = progress_per_file / max(1, len(hf_weights) + len(qkv_weights))
+
+    for from_name, param in hf_weights.items():
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        to_name = weight_map[name_template]
+        param = load_param(param, from_name, dtype, verbose=debug_mode)
+        if any(w in from_name for w in ("q_proj", "k_proj", "v_proj")):
+            qkv = qkv_weights.setdefault(ids[0], defaultdict(dict))
+            weight_name, weight_type = from_name.split(".")[-2:]
+            qkv[weight_type][weight_name] = param
+        if to_name is None:
+            continue
+        to_name = to_name.format(*ids)
+        if saver is not None:
+            param = saver.store_early(param)
+        state_dict[to_name] = param
+
+        if progress_per_file is not None:
+            pbar.update(progress_per_file)
+
+    if "lm_head.weight" not in state_dict:
+        state_dict["lm_head.weight"] = state_dict["transformer.wte.weight"]
+
+    for i in list(qkv_weights):
+        for weight_type in list(qkv_weights[i]):
+            qkv = qkv_weights[i][weight_type]
+            if len(qkv) != 3:
+                # qkv is split across different .bin files
+                continue
+            q = load_param(qkv["q_proj"], f"layer {i} q {weight_type}", dtype, verbose=debug_mode)
+            k = load_param(qkv["k_proj"], f"layer {i} k {weight_type}", dtype, verbose=debug_mode)
+            v = load_param(qkv["v_proj"], f"layer {i} v {weight_type}", dtype, verbose=debug_mode)
+            qkv = torch.cat((q, k, v))
+            state_dict[f"transformer.h.{i}.attn.qkv.{weight_type}"] = qkv
+            del qkv_weights[i][weight_type]
+
+            if progress_per_file is not None:
+                pbar.update(progress_per_file)
+
+
 def qkv_reassemble(
     param: Union[torch.Tensor, NotYetLoadedTensor], config: Config
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
@@ -624,6 +693,10 @@ def convert_hf_checkpoint(
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
        copy_fn = partial(copy_weights_qwen_2_5, config, qkv_weights)
+    elif model_name.lower().startswith("qwen3"):
+        # holder to reconstitute the split q, k, v
+        qkv_weights = {}
+        copy_fn = partial(copy_weights_qwen_3, config, qkv_weights)
     elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
         # holder to reconstitute the split q, k, v
         qkv_weights = {}
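```

As a sanity check on the fused layout built by `torch.cat((q, k, v))` above: for Qwen3-0.6B (n_head=16, n_query_groups=8, head_size=128, n_embd=1024) the per-projection and fused shapes work out as below. This is an illustration with dummy tensors, not part of the commit:

```python
import torch

# Illustration only: expected HF projection shapes for Qwen3-0.6B and the fused litgpt qkv matrix.
n_embd, n_head, n_query_groups, head_size = 1024, 16, 8, 128

q = torch.empty(n_head * head_size, n_embd)           # (2048, 1024)  q_proj.weight
k = torch.empty(n_query_groups * head_size, n_embd)   # (1024, 1024)  k_proj.weight
v = torch.empty(n_query_groups * head_size, n_embd)   # (1024, 1024)  v_proj.weight

qkv = torch.cat((q, k, v))
print(qkv.shape)  # torch.Size([4096, 1024]) -> stored as transformer.h.{i}.attn.qkv.weight
```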

litgpt/scripts/convert_lit_checkpoint.py

Lines changed: 52 additions & 0 deletions
```diff
@@ -393,6 +393,56 @@ def copy_weights_qwen_2_5(
         state_dict[to_name] = param
 
 
+def copy_weights_qwen_3(
+    config: Config,
+    state_dict: Dict[str, torch.Tensor],
+    lit_weights: Dict[str, Union[torch.Tensor, NotYetLoadedTensor]],
+    untie_weights: bool = False,
+    saver: Optional[incremental_save] = None,
+) -> None:
+    weight_map = {
+        "transformer.wte.weight": "model.embed_tokens.weight",
+        "transformer.h.{}.norm_1.weight": "model.layers.{}.input_layernorm.weight",
+        "transformer.h.{}.norm_2.weight": "model.layers.{}.post_attention_layernorm.weight",
+        "transformer.h.{}.attn.proj.weight": "model.layers.{}.self_attn.o_proj.weight",
+        "transformer.h.{}.attn.norm_q.weight": "model.layers.{}.self_attn.q_norm.weight",
+        "transformer.h.{}.attn.norm_k.weight": "model.layers.{}.self_attn.k_norm.weight",
+        "transformer.h.{}.mlp.fc_1.weight": "model.layers.{}.mlp.gate_proj.weight",
+        "transformer.h.{}.mlp.fc_2.weight": "model.layers.{}.mlp.up_proj.weight",
+        "transformer.h.{}.mlp.proj.weight": "model.layers.{}.mlp.down_proj.weight",
+        "transformer.ln_f.weight": "model.norm.weight",
+        "lm_head.weight": "lm_head.weight",
+    }
+
+    for from_name, param in lit_weights.items():
+        if from_name == "lm_head.weight" and untie_weights:
+            continue
+        name_template, *ids = layer_template(from_name, num_matches=2)
+        param = load_param(param, from_name, None)
+        if from_name.endswith(".attn.qkv.weight"):
+            weight_type = from_name.split(".")[-1]  # weight or bias
+            to_names = (
+                "model.layers.{}.self_attn.q_proj.{}".format(*ids, weight_type),
+                "model.layers.{}.self_attn.k_proj.{}".format(*ids, weight_type),
+                "model.layers.{}.self_attn.v_proj.{}".format(*ids, weight_type),
+            )
+            params = param.split(
+                (
+                    config.n_head * config.head_size,
+                    config.n_query_groups * config.head_size,
+                    config.n_query_groups * config.head_size,
+                )
+            )
+        else:
+            to_names = (weight_map[name_template].format(*ids),)
+            params = (param,)
+
+        for to_name, param in zip(to_names, params):
+            if saver is not None:
+                param = saver.store_early(param)
+            state_dict[to_name] = param
+
+
 def qkv_reassemble(param: Union[torch.Tensor, NotYetLoadedTensor], config: Config) -> torch.Tensor:
     """Reassemble from a normal to an interleaved placement in a QKV matrix.
     [Q, Q, ..., K, K, ..., V, V, ...] --> [Q, K, V, Q, K, V, ...]
@@ -437,6 +487,8 @@ def convert_lit_checkpoint(checkpoint_dir: Path, output_dir: Path) -> None:
         copy_fn = partial(copy_weights_phi, config)
     elif config.name.lower().startswith(("qwen2.5", "qwq")):
         copy_fn = partial(copy_weights_qwen_2_5, config)
+    elif config.name.lower().startswith("qwen3"):
+        copy_fn = partial(copy_weights_qwen_3, config)
     elif config.mlp_class_name in ("LLaMAMLP", "GemmaMLP", "LLaMAMoE"):
         untie_weights = "Gemma" in config.name
         copy_fn = partial(copy_weights_llama, config, untie_weights=untie_weights)
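```

The reverse conversion relies on `torch.Tensor.split` accepting a tuple of section sizes, so the fused matrix separates back into q/k/v with sizes computed from the config. A small illustration with the same Qwen3-0.6B dimensions, not part of the commit:

```python
import torch

# Illustration only: splitting the fused qkv matrix back into q/k/v, Qwen3-0.6B sizes.
n_embd, n_head, n_query_groups, head_size = 1024, 16, 8, 128

qkv = torch.empty((n_head + 2 * n_query_groups) * head_size, n_embd)  # (4096, 1024)
q, k, v = qkv.split(
    (
        n_head * head_size,          # 2048 rows -> q_proj.weight
        n_query_groups * head_size,  # 1024 rows -> k_proj.weight
        n_query_groups * head_size,  # 1024 rows -> v_proj.weight
    )
)
print(q.shape, k.shape, v.shape)  # (2048, 1024) (1024, 1024) (1024, 1024)
```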
