Skip to content

Commit 1086772

Browse files
ko3n1gjQizhangcuichenx
authored
Add DeepSeek-R1 Distillation NeMo 2.0 tutorial (#12187) (#12355)
* add distillation tutorial
* add reason data generation tutorial
* minor fixes
* add readme
* fix qwen importer
* fix minor bugs
* remove hack
* pylint
* Apply isort and black reformatting
* pylint

---------

Signed-off-by: Chen Cui <chcui@nvidia.com>
Signed-off-by: Chen Cui <cxcui@alumni.cmu.edu>
Signed-off-by: cuichenx <cuichenx@users.noreply.github.com>
Co-authored-by: bigbigQI <871052938@qq.com>
Co-authored-by: lark zhang <larkz@nvidia.com>
Co-authored-by: Chen Cui <chcui@nvidia.com>
Co-authored-by: Chen Cui <cxcui@alumni.cmu.edu>
Co-authored-by: cuichenx <cuichenx@users.noreply.github.com>
1 parent e265bc0 commit 1086772

File tree

4 files changed

+1173
-3
lines changed

4 files changed

+1173
-3
lines changed

nemo/collections/llm/gpt/model/qwen2.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636

3737
@dataclass
3838
class Qwen2Config(GPTConfig):
39+
"""
40+
Base config for Qwen 2 Models
41+
"""
42+
3943
normalization: str = "RMSNorm"
4044
activation_func: Callable = F.silu
4145
gated_linear_unit: bool = True
@@ -54,6 +58,10 @@ class Qwen2Config(GPTConfig):
5458

5559
@dataclass
5660
class Qwen2Config500M(Qwen2Config):
61+
"""
62+
Config for Qwen 2 0.5B: https://huggingface.co/Qwen/Qwen2-0.5B
63+
"""
64+
5765
num_layers: int = 24
5866
hidden_size: int = 896
5967
num_attention_heads: int = 14
@@ -63,6 +71,10 @@ class Qwen2Config500M(Qwen2Config):
6371

6472
@dataclass
6573
class Qwen2Config1P5B(Qwen2Config):
74+
"""
75+
Config for Qwen 2 1.5B: https://huggingface.co/Qwen/Qwen2-1.5B
76+
"""
77+
6678
num_layers: int = 28
6779
hidden_size: int = 1536
6880
num_attention_heads: int = 12
@@ -72,6 +84,10 @@ class Qwen2Config1P5B(Qwen2Config):
7284

7385
@dataclass
7486
class Qwen2Config7B(Qwen2Config):
87+
"""
88+
Config for Qwen 2 7B: https://huggingface.co/Qwen/Qwen2-7B
89+
"""
90+
7591
num_layers: int = 28
7692
hidden_size: int = 3584
7793
num_attention_heads: int = 28
@@ -82,17 +98,24 @@ class Qwen2Config7B(Qwen2Config):
8298

8399
@dataclass
84100
class Qwen2Config72B(Qwen2Config):
101+
"""
102+
Config for Qwen 2 72B: https://huggingface.co/Qwen/Qwen2-72B
103+
"""
104+
85105
num_layers: int = 80
86106
hidden_size: int = 8192
87107
num_attention_heads: int = 64
88108
num_query_groups: int = 8
89109
ffn_hidden_size: int = 29568
90110
vocab_size: int = 152064
91111
layernorm_epsilon: float = 1e-5
92-
vocab_size: int = 152064
93112

94113

95114
class Qwen2Model(GPTModel):
115+
"""
116+
Base model for Qwen 2
117+
"""
118+
96119
def __init__(
97120
self,
98121
config: Annotated[Optional[Qwen2Config], Config[Qwen2Config]] = None,
@@ -105,6 +128,7 @@ def __init__(
105128

106129
@io.model_importer(Qwen2Model, "hf")
107130
class HFQwen2Importer(io.ModelConnector["AutoModelForCausalLM", Qwen2Model]):
131+
# pylint: disable=C0115,C0116
108132
def init(self) -> Qwen2Model:
109133
return Qwen2Model(self.config, tokenizer=self.tokenizer)
110134

@@ -163,6 +187,8 @@ def config(self) -> Qwen2Config:
163187
make_vocab_size_divisible_by=128,
164188
rotary_base=source.rope_theta,
165189
share_embeddings_and_output_weights=False,
190+
vocab_size=source.vocab_size,
191+
seq_length=source.max_position_embeddings,
166192
fp16=(dtype_from_hf(source) == torch.float16),
167193
bf16=(dtype_from_hf(source) == torch.bfloat16),
168194
params_dtype=dtype_from_hf(source),
@@ -173,6 +199,7 @@ def config(self) -> Qwen2Config:
173199

174200
@io.model_exporter(Qwen2Model, "hf")
175201
class HFQwen2Exporter(io.ModelConnector[Qwen2Model, "AutoModelForCausalLM"]):
202+
# pylint: disable=C0115,C0116
176203
def init(self, dtype=torch.bfloat16) -> "AutoModelForCausalLM":
177204
from transformers import AutoModelForCausalLM
178205
from transformers.modeling_utils import no_init_weights
@@ -288,7 +315,6 @@ def _import_qkv_bias(ctx: io.TransformCTX, q, k, v):
288315
head_num = megatron_config.num_attention_heads
289316
num_query_groups = megatron_config.num_query_groups
290317
heads_per_group = head_num // num_query_groups
291-
hidden_size = megatron_config.hidden_size
292318
head_size = megatron_config.kv_channels
293319

294320
new_q_tensor_shape = (head_num, head_size)
@@ -360,7 +386,6 @@ def _export_qkv_bias(ctx: io.TransformCTX, qkv_bias):
360386
head_num = megatron_config.num_attention_heads
361387
num_query_groups = megatron_config.num_query_groups
362388
heads_per_group = head_num // num_query_groups
363-
hidden_size = megatron_config.hidden_size
364389
head_size = megatron_config.kv_channels
365390
qkv_total_dim = head_num + 2 * num_query_groups
366391

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Distilling the Reasoning Ability of DeepSeek R1 into Qwen with the NeMo 2.0 Framework
2+
=====================================================================================
3+
4+
DeepSeek R1 is an open-source large language model dedicated to solving logical reasoning tasks. It employs a Mixture of Experts (MoE) architecture and boasts 671B parameters. Through reinforcement learning, it has been trained to perform deep thinking (generating long-chain-of-thought), excelling in reasoning tasks and various specialized fields such as mathematics, programming, and scientific analysis.
5+
6+
Moreover, as per the `DeepSeek-R1 <https://arxiv.org/abs/2501.12948>`_ paper, the reasoning patterns of larger models can be distilled into smaller ones. Specifically, we can distill long-chain-of-thought (long-CoT) data that encapsulates reasoning processes from DeepSeek-R1 and directly fine-tune open-source models like Qwen and Llama. This simple distillation approach greatly enhances the reasoning capabilities of smaller models.
7+
8+
9+
To illustrate the complete distillation process, we have prepared two notebooks demonstrating how to extract reasoning data from DeepSeek-R1 using the NIM API, and how to train models with the distilled data.
10+
11+
* `generate_reasoning_data.ipynb <./generate_reasoning_data.ipynb>`_ demonstrates the process of distilling reasoning data from DeepSeek-R1 using the NIM API.
12+
* `qwen2_distill_nemo.ipynb <./qwen2_distill_nemo.ipynb>`_ shows how to train open-source models with the distilled data.

0 commit comments

Comments
 (0)