|
1746 | 1746 | },
|
1747 | 1747 | "outputs": [],
|
1748 | 1748 | "source": [
|
1749 |
| - "def train_sft(model, tokenizer, dataset):\n", |
| 1749 | + "def train_sft(model, tokenizer, sft_dataset, sft_config):\n", |
1750 | 1750 | "    \"\"\"Learn the answer format with SFT\"\"\"\n",
|
1751 | 1751 | "\n",
|
1752 |
| - " # SFTの設定\n", |
1753 |
| - " sft_config = SFTConfig(\n", |
1754 |
| - " dataset_text_field=\"text\",\n", |
1755 |
| - " per_device_train_batch_size=2,\n", |
1756 |
| - " gradient_accumulation_steps=2,\n", |
1757 |
| - " warmup_ratio=0.1,\n", |
1758 |
| - " num_train_epochs=2, # LLMのSFTだと一般的ですかね.\n", |
1759 |
| - " learning_rate=2e-5,\n", |
1760 |
| - " logging_steps=5,\n", |
1761 |
| - " optim=\"adamw_torch\",\n", |
1762 |
| - " weight_decay=0.01,\n", |
1763 |
| - " lr_scheduler_type=\"linear\",\n", |
1764 |
| - " seed=3407,\n", |
1765 |
| - " output_dir=\"outputs_sft\",\n", |
1766 |
| - " report_to=\"wandb\"\n", |
1767 |
| - " )\n", |
1768 |
| - "\n", |
1769 | 1752 | "    # Create and run the SFT trainer\n",
|
1770 | 1753 | " sft_trainer = SFTTrainer(\n",
|
1771 | 1754 | " model=model,\n",
|
|
1777 | 1760 | "    print(\"######### Starting training #########\")\n",
|
1778 | 1761 | " sft_trainer.train()\n",
|
1779 | 1762 | "\n",
|
1780 |
| - " return model, sft_config" |
| 1763 | + " return model " |
1781 | 1764 | ]
|
1782 | 1765 | },
|
1783 | 1766 | {
|
|
1880 | 1863 | },
|
1881 | 1864 | {
|
1882 | 1865 | "cell_type": "code",
|
1883 |
| - "execution_count": 14, |
1884 |
| - "metadata": { |
1885 |
| - "id": "pRChHaTacB4T" |
1886 |
| - }, |
| 1866 | + "execution_count": null, |
| 1867 | + "metadata": {}, |
1887 | 1868 | "outputs": [],
|
1888 | 1869 | "source": [
|
1889 |
| - "reward_funcs = [reward_answer_correctness]\n", |
1890 |
| - "def train_grpo(model, tokenizer, dataset, reward_funcs, MAX_STEP):\n", |
| 1870 | + "def train_grpo(model, tokenizer, dataset, reward_funcs, grpo_config):\n", |
1891 | 1871 | "    \"\"\"Run GRPO training (reasoning mode enabled)\"\"\"\n",
|
1892 | 1872 | "\n",
|
1893 |
| - " # トークン長を計算(enable_thinking=Trueで)\n", |
1894 |
| - " def get_token_length(example):\n", |
1895 |
| - " return len(tokenizer.apply_chat_template(\n", |
1896 |
| - " example[\"prompt\"],\n", |
1897 |
| - " add_generation_prompt=True,\n", |
1898 |
| - " tokenize=True,\n", |
1899 |
| - " ))\n", |
1900 |
| - "\n", |
1901 |
| - " dataset = dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n", |
1902 |
| - " max_prompt_length = int(np.quantile(dataset[\"token_length\"], 0.9))\n", |
1903 |
| - " dataset = dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n", |
1904 |
| - "\n", |
1905 |
| - " # GRPO設定\n", |
1906 |
| - " \"\"\"\n", |
1907 |
| - " 今回は特にoptimなども設定せずデフォルトで動かします.\n", |
1908 |
| - " \"\"\"\n", |
1909 |
| - " training_args = GRPOConfig(\n", |
1910 |
| - " temperature=0.6, # Qwen3の推奨値\n", |
1911 |
| - " learning_rate=5e-6, #LoRAなんでやや高めで.\n", |
1912 |
| - " per_device_train_batch_size=1,\n", |
1913 |
| - " gradient_accumulation_steps=8,\n", |
1914 |
| - " num_generations=4, # 同時に生成するサンプル数です.基本4以上が推奨です.ここを2にするとオンラインDPOを同じようなことをやっていることになります(lossの仕組みとか違いますが)\n", |
1915 |
| - " max_prompt_length=max_prompt_length,\n", |
1916 |
| - " max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n", |
1917 |
| - " max_steps = MAX_STEP,\n", |
1918 |
| - " save_steps=100,\n", |
1919 |
| - " output_dir=\"outputs\",\n", |
1920 |
| - " report_to=\"wandb\"\n", |
1921 |
| - " )\n", |
1922 |
| - "\n", |
1923 | 1873 | " trainer = GRPOTrainer(\n",
|
1924 | 1874 | " model=model,\n",
|
1925 | 1875 | " processing_class=tokenizer,\n",
|
1926 | 1876 | " reward_funcs=reward_funcs,\n",
|
1927 |
| - " args=training_args,\n", |
| 1877 | + " args=grpo_config,\n", |
1928 | 1878 | " train_dataset=dataset,\n",
|
1929 | 1879 | " )\n",
|
1930 | 1880 | "\n",
|
1931 | 1881 | "    print(\"######### Starting training #########\")\n",
|
1932 | 1882 | " trainer.train()\n",
|
1933 |
| - " return model, training_args" |
| 1883 | + " return model" |
1934 | 1884 | ]
|
1935 | 1885 | },
|
1936 | 1886 | {
|
|
3321 | 3271 | "sft_dataset = prepare_sft_dataset()"
|
3322 | 3272 | ]
|
3323 | 3273 | },
|
| 3274 | + { |
| 3275 | + "cell_type": "code", |
| 3276 | + "execution_count": null, |
| 3277 | + "metadata": {}, |
| 3278 | + "outputs": [], |
| 3279 | + "source": [ |
| 3280 | + "\n", |
| 3281 | + "# SFTの設定\n", |
| 3282 | + "sft_config = SFTConfig(\n", |
| 3283 | + " dataset_text_field=\"text\",\n", |
| 3284 | + " per_device_train_batch_size=2,\n", |
| 3285 | + " gradient_accumulation_steps=2,\n", |
| 3286 | + " warmup_ratio=0.1,\n", |
| 3287 | + "    num_train_epochs=2, # a common choice for LLM SFT\n", |
| 3288 | + " learning_rate=2e-5,\n", |
| 3289 | + " logging_steps=5,\n", |
| 3290 | + " optim=\"adamw_torch\",\n", |
| 3291 | + " weight_decay=0.01,\n", |
| 3292 | + " lr_scheduler_type=\"linear\",\n", |
| 3293 | + " seed=3407,\n", |
| 3294 | + " output_dir=\"outputs_sft\",\n", |
| 3295 | + " report_to=\"wandb\"\n", |
| 3296 | + ")\n" |
| 3297 | + ] |
| 3298 | + }, |
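| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# Illustrative sanity check (not part of the original training flow):\n", |
| + "# dataset_text_field=\"text\" above assumes sft_dataset exposes a \"text\" column.\n", |
| + "assert \"text\" in sft_dataset.column_names" |
| + ] |
| + }, |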
3324 | 3299 | {
|
3325 | 3300 | "cell_type": "code",
|
3326 | 3301 | "execution_count": 17,
|
|
3833 | 3808 | "source": [
|
3834 | 3809 | "print(\"\\n=== 2. SFT (supervised fine-tuning) phase ===\")\n",
|
3835 | 3810 | "print(\"First, teach the model the correct answer format...\")\n",
|
3836 |
| - "model, sft_config = train_sft(model, tokenizer, sft_dataset)" |
| 3811 | + "model = train_sft(model, tokenizer, sft_dataset, sft_config)" |
3837 | 3812 | ]
|
3838 | 3813 | },
|
3839 | 3814 | {
|
|
3921 | 3896 | ]
|
3922 | 3897 | },
|
3923 | 3898 | {
|
3924 |
| - "cell_type": "code", |
3925 |
| - "execution_count": 20, |
| 3899 | + "cell_type": "markdown", |
3926 | 3900 | "metadata": {
|
3927 |
| - "colab": { |
3928 |
| - "base_uri": "https://localhost:8080/", |
3929 |
| - "height": 197, |
3930 |
| - "referenced_widgets": [ |
3931 |
| - "a60127b1fcb7473d958799b0f7313d8d", |
3932 |
| - "aed570b0ec634435bf5e21bcdaddb7be", |
3933 |
| - "9956cdb890cf45769b30cdf48a8605f5", |
3934 |
| - "8d047c122f344ae3b3108c21ecd67473", |
3935 |
| - "8e5ad14df9044026a2cdaf4caaac8af0", |
3936 |
| - "529a4caafe804eb7b4bc39e24c20cae3", |
3937 |
| - "4648b1ce4c3a44888d29b60a224d93f4", |
3938 |
| - "d6a2b375489943f38514123449c3f042", |
3939 |
| - "f1631034ee1d417c9cdaf45da2f3fb97", |
3940 |
| - "8d2fd0b48d5943a380dc13bfe33dfb4c", |
3941 |
| - "bb8f6f29f2bb49628fb43fea3304397f", |
3942 |
| - "3edc302aa9ea40b7a8e659fcb91df66f", |
3943 |
| - "3c6f7e76bba34c58ad83aed8a9d7af7f", |
3944 |
| - "8bd87551cab74be3b1126f936a9dc183", |
3945 |
| - "95c4720a8fc0452cbb58fe2b94173877", |
3946 |
| - "10cab1379dc449ad8ac3779fa8a2f111", |
3947 |
| - "6884e56c7711429aae81cbb460046d3e", |
3948 |
| - "501e94f452a041cf87f9fd6c13a57f81", |
3949 |
| - "1a63c469e15843afa938f2a2efd14250", |
3950 |
| - "7d42059ea7664db5b6c992468084dfd5", |
3951 |
| - "37f5abce183d485c90eb9271885770e5", |
3952 |
| - "762b47fedb4947eb8d5cef343ecb6269", |
3953 |
| - "cbe0d1f73e2844719fd09671e0e95f75", |
3954 |
| - "7d2f7d0e72d947da9a354a025c21880f", |
3955 |
| - "aed3745450714c518a1e64892ac5569e", |
3956 |
| - "f72202b8bf8d444db31dac051b724472", |
3957 |
| - "20afc7a33eba42aabc6e4af0b2283f90", |
3958 |
| - "aa3618c5a4ce4e83840bad915008f765", |
3959 |
| - "bbc85c1b87e74a0695177115b22b6eda", |
3960 |
| - "0451c5d59dd047e4b11372c986d8e6dd", |
3961 |
| - "dbb3b5fc644544ce8c33fcbbc931d376", |
3962 |
| - "215f10a2764b418eb11c76d147ae283e", |
3963 |
| - "2c9f4aa9f9b2439aa9001c3788c6ad54", |
3964 |
| - "c0a5edb95f9647fbb71b4f6150da5e04", |
3965 |
| - "43c4759361ab42458fa9485a5499702e", |
3966 |
| - "cd689bf6174349bdbf2bc59db8b6cfe0", |
3967 |
| - "2a7cd69621ee4b559274d9b68485c326", |
3968 |
| - "1bf9394c4dc7489fa7a578c86efd902b", |
3969 |
| - "36a29633fd0b4c679011c5f4eeb6c46d", |
3970 |
| - "72b61da2d2b7440b9704f8692b9c27c5", |
3971 |
| - "4d8c3207481247cb86c9706ce57f6f73", |
3972 |
| - "a764fa30ee6b4032a731e7bcfb0c7fe8", |
3973 |
| - "00412d4ee75540b9a543d9ead7b3526d", |
3974 |
| - "adcde5a3d87643a3b32e376052c84035" |
3975 |
| - ] |
3976 |
| - }, |
3977 |
| - "id": "qdVm9pTbcB0U", |
3978 |
| - "outputId": "7690ce58-a158-4434-9a19-89ce72432844" |
| 3901 | + "id": "ABQR0hjAGiA0" |
3979 | 3902 | },
|
3980 |
| - "outputs": [ |
3981 |
| - { |
3982 |
| - "name": "stdout", |
3983 |
| - "output_type": "stream", |
3984 |
| - "text": [ |
3985 |
| - "\n", |
3986 |
| - "=== 3. GRPO(強化学習)フェーズ ===\n", |
3987 |
| - "GRPOデータセット準備中...\n" |
3988 |
| - ] |
3989 |
| - }, |
3990 |
| - { |
3991 |
| - "data": { |
3992 |
| - "application/vnd.jupyter.widget-view+json": { |
3993 |
| - "model_id": "a60127b1fcb7473d958799b0f7313d8d", |
3994 |
| - "version_major": 2, |
3995 |
| - "version_minor": 0 |
3996 |
| - }, |
3997 |
| - "text/plain": [ |
3998 |
| - "README.md: 0%| | 0.00/603 [00:00<?, ?B/s]" |
3999 |
| - ] |
4000 |
| - }, |
4001 |
| - "metadata": {}, |
4002 |
| - "output_type": "display_data" |
4003 |
| - }, |
4004 |
| - { |
4005 |
| - "data": { |
4006 |
| - "application/vnd.jupyter.widget-view+json": { |
4007 |
| - "model_id": "3edc302aa9ea40b7a8e659fcb91df66f", |
4008 |
| - "version_major": 2, |
4009 |
| - "version_minor": 0 |
4010 |
| - }, |
4011 |
| - "text/plain": [ |
4012 |
| - "data/cot-00000-of-00001.parquet: 0%| | 0.00/106M [00:00<?, ?B/s]" |
4013 |
| - ] |
4014 |
| - }, |
4015 |
| - "metadata": {}, |
4016 |
| - "output_type": "display_data" |
4017 |
| - }, |
4018 |
| - { |
4019 |
| - "data": { |
4020 |
| - "application/vnd.jupyter.widget-view+json": { |
4021 |
| - "model_id": "cbe0d1f73e2844719fd09671e0e95f75", |
4022 |
| - "version_major": 2, |
4023 |
| - "version_minor": 0 |
4024 |
| - }, |
4025 |
| - "text/plain": [ |
4026 |
| - "Generating cot split: 0%| | 0/19252 [00:00<?, ? examples/s]" |
4027 |
| - ] |
4028 |
| - }, |
4029 |
| - "metadata": {}, |
4030 |
| - "output_type": "display_data" |
4031 |
| - }, |
4032 |
| - { |
4033 |
| - "data": { |
4034 |
| - "application/vnd.jupyter.widget-view+json": { |
4035 |
| - "model_id": "c0a5edb95f9647fbb71b4f6150da5e04", |
4036 |
| - "version_major": 2, |
4037 |
| - "version_minor": 0 |
4038 |
| - }, |
4039 |
| - "text/plain": [ |
4040 |
| - "Map: 0%| | 0/19252 [00:00<?, ? examples/s]" |
4041 |
| - ] |
4042 |
| - }, |
4043 |
| - "metadata": {}, |
4044 |
| - "output_type": "display_data" |
4045 |
| - } |
4046 |
| - ], |
4047 | 3903 | "source": [
|
| 3904 | + "### Run GRPO training\n", |
| 3905 | + "\n", |
| 3906 | + "Since max_steps is set to 100, training runs for 100 steps (with these settings the effective batch size is 8, so the model is trained on 800 examples from the dataset; see the quick check after the config cell below)." |
| 3907 | + ] |
| 3908 | + }, |
| 3909 | + { |
| 3910 | + "cell_type": "code", |
| 3911 | + "execution_count": null, |
| 3912 | + "metadata": {}, |
| 3913 | + "outputs": [], |
| 3914 | + "source": [ |
| 3915 | + "\n", |
4048 | 3916 | "print(\"\\n=== 3. GRPO (reinforcement learning) phase ===\")\n",
|
4049 | 3917 | "print(\"Preparing the GRPO dataset...\")\n",
|
4050 | 3918 | "# Add a placeholder SYSTEM_PROMPT as it's a required argument\n",
|
4051 |
| - "grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # 修正" |
| 3919 | + "grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # fixed\n", |
| 3920 | + "\n", |
| 3921 | + "# Compute token lengths (with enable_thinking=True)\n", |
| 3922 | + "def get_token_length(example):\n", |
| 3923 | + " return len(tokenizer.apply_chat_template(\n", |
| 3924 | + " example[\"prompt\"],\n", |
| 3925 | + " add_generation_prompt=True,\n", |
| 3926 | + " tokenize=True,\n", |
| 3927 | + " ))\n", |
| 3928 | + "\n", |
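| + "# Keep prompts up to the 90th-percentile token length so completions still fit in MAX_SEQ_LENGTH\n", |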
| 3929 | + "grpo_dataset = grpo_dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n", |
| 3930 | + "max_prompt_length = int(np.quantile(grpo_dataset[\"token_length\"], 0.9))\n", |
| 3931 | + "grpo_dataset = grpo_dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n", |
| 3932 | + "\n", |
| 3933 | + "MAX_STEP = 100\n", |
| 3934 | + "grpo_config = GRPOConfig(\n", |
| 3935 | + "    temperature=0.6, # recommended value for Qwen3\n", |
| 3936 | + "    learning_rate=5e-6, # slightly high, since we train LoRA adapters\n", |
| 3937 | + " per_device_train_batch_size=1,\n", |
| 3938 | + " gradient_accumulation_steps=8,\n", |
| 3939 | + "    num_generations=4, # number of samples generated per prompt; 4 or more is generally recommended. With 2, this becomes similar to online DPO (though the loss works differently).\n", |
| 3940 | + " max_prompt_length=max_prompt_length,\n", |
| 3941 | + " max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n", |
| 3942 | + "    max_steps=MAX_STEP,\n", |
| 3943 | + " save_steps=100,\n", |
| 3944 | + " output_dir=\"outputs\",\n", |
| 3945 | + " report_to=\"wandb\"\n", |
| 3946 | + ")" |
4052 | 3947 | ]
|
4053 | 3948 | },
|
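| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# Quick check of the arithmetic above (illustrative only): with\n", |
| + "# per_device_train_batch_size=1 and gradient_accumulation_steps=8, the\n", |
| + "# effective batch size is 8, so 100 steps consume 8 * 100 = 800 examples.\n", |
| + "effective_batch = grpo_config.per_device_train_batch_size * grpo_config.gradient_accumulation_steps\n", |
| + "print(effective_batch, effective_batch * MAX_STEP)  # -> 8 800" |
| + ] |
| + }, |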
4054 | 3949 | {
|
4055 |
| - "cell_type": "markdown", |
4056 |
| - "metadata": { |
4057 |
| - "id": "ABQR0hjAGiA0" |
4058 |
| - }, |
| 3950 | + "cell_type": "code", |
| 3951 | + "execution_count": null, |
| 3952 | + "metadata": {}, |
| 3953 | + "outputs": [], |
4059 | 3954 | "source": [
|
4060 |
| - "### GRPO学習します.\n", |
4061 |
| - "\n", |
4062 |
| - "今回はmaxstepを100にしているので,100step(設定上実質バッチサイズは8なので,データセットのうち800件で学習したことになります.)" |
| 3955 | + "reward_funcs = [reward_answer_correctness]" |
4063 | 3956 | ]
|
4064 | 3957 | },
|
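| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# A minimal sketch of the reward-function interface (hypothetical example,\n", |
| + "# not used for training): GRPOTrainer calls each reward function with the\n", |
| + "# prompts, the sampled completions, and extra dataset columns as kwargs,\n", |
| + "# and expects one float score per completion.\n", |
| + "def reward_short_answers(prompts, completions, **kwargs):\n", |
| + "    # With conversational datasets each completion is a list of message dicts.\n", |
| + "    texts = [c if isinstance(c, str) else c[0][\"content\"] for c in completions]\n", |
| + "    return [-len(t) / 1000.0 for t in texts]" |
| + ] |
| + }, |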
4065 | 3958 | {
|
|
4211 | 4104 | ],
|
4212 | 4105 | "source": [
|
4213 | 4106 | "print(\"3. Starting GRPO training...\")\n",
|
4214 |
| - "model, grpo_config = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, 100)" |
| 4107 | + "\n", |
| 4108 | + "model = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, grpo_config)" |
4215 | 4109 | ]
|
4216 | 4110 | },
|
4217 | 4111 | {
|
|
4264 | 4158 | ")"
|
4265 | 4159 | ]
|
4266 | 4160 | },
|
4267 |
| - { |
4268 |
| - "cell_type": "code", |
4269 |
| - "execution_count": 24, |
4270 |
| - "metadata": { |
4271 |
| - "id": "4uKT5NtLaM9Y" |
4272 |
| - }, |
4273 |
| - "outputs": [], |
4274 |
| - "source": [] |
4275 |
| - }, |
4276 | 4161 | {
|
4277 | 4162 | "cell_type": "code",
|
4278 | 4163 | "execution_count": 25,
|
|
4371 | 4256 | }
|
4372 | 4257 | ],
|
4373 | 4258 | "source": [
|
4374 |
| - "model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)" |
| 4259 | + "model = model.merge_and_unload()\n", |
| 4260 | + "model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)\n", |
| 4261 | + "tokenizer.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)" |
4375 | 4262 | ]
|
4376 | 4263 | },
|
4377 | 4264 | {
|
|
4381 | 4268 | },
|
4382 | 4269 | "source": [
|
4383 | 4270 | "# Run evaluation and upload to wandb\n",
|
4384 |
| - "- GPUメモリが足りない場合があります。その場合はセッション再起動をして、ここからやり直してください。" |
| 4271 | + "- You may run out of GPU memory. In that case, restart the session and start over from here.\n", |
| 4272 | + "- If the session dies but the model upload to HF has finished, you can resume from here by re-running the install cell, the HF and WandB login cells, and the cells that define sft_config and grpo_config; a reload sketch follows below." |
4385 | 4273 | ]
|
4386 | 4274 | },
|
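| + { |
| + "cell_type": "code", |
| + "execution_count": null, |
| + "metadata": {}, |
| + "outputs": [], |
| + "source": [ |
| + "# Reload sketch (illustrative, assuming the upload above finished): pull the\n", |
| + "# merged checkpoint back from the Hub at the revision pushed earlier.\n", |
| + "from transformers import AutoModelForCausalLM, AutoTokenizer\n", |
| + "\n", |
| + "model = AutoModelForCausalLM.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)\n", |
| + "tokenizer = AutoTokenizer.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)" |
| + ] |
| + }, |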
4387 | 4275 | {
|
|
4774 | 4662 | " eval_config = yaml.safe_load(f)\n",
|
4775 | 4663 | "\n",
|
4776 | 4664 | "eval_config[\"model_parameters\"][\"model_name\"] = LEARNED_MODEL_NAME\n",
|
4777 |
| - "eval_config[\"model_parameters\"][\"run_name\"] = RUN_NAME \n", |
| 4665 | + "eval_config[\"model_parameters\"][\"revision\"] = RUN_NAME\n", |
4778 | 4666 | "with open(f\"{eval_code_path}/eval_config.yaml\", \"w\") as f:\n",
|
4779 | 4667 | " yaml.dump(eval_config, f)\n"
|
4780 | 4668 | ]
|