
Commit 71aab67

Merge pull request #15 from llmcompe2025-team-semishigure/fix_bug
Made it easier to re-run just the evaluation
2 parents abb7752 + a54d2bd commit 71aab67

File tree

1 file changed: +88 -200 lines changed


rev_simple_unsloth_sft+grpo.ipynb

Lines changed: 88 additions & 200 deletions
@@ -1746,26 +1746,9 @@
   },
   "outputs": [],
   "source": [
-  "def train_sft(model, tokenizer, dataset):\n",
+  "def train_sft(model, tokenizer, sft_dataset, sft_config):\n",
   "    \"\"\"Learn the answer format with SFT\"\"\"\n",
   "\n",
-  "    # SFT settings\n",
-  "    sft_config = SFTConfig(\n",
-  "        dataset_text_field=\"text\",\n",
-  "        per_device_train_batch_size=2,\n",
-  "        gradient_accumulation_steps=2,\n",
-  "        warmup_ratio=0.1,\n",
-  "        num_train_epochs=2,  # Fairly standard for LLM SFT.\n",
-  "        learning_rate=2e-5,\n",
-  "        logging_steps=5,\n",
-  "        optim=\"adamw_torch\",\n",
-  "        weight_decay=0.01,\n",
-  "        lr_scheduler_type=\"linear\",\n",
-  "        seed=3407,\n",
-  "        output_dir=\"outputs_sft\",\n",
-  "        report_to=\"wandb\"\n",
-  "    )\n",
-  "\n",
   "    # Create and run the SFT trainer\n",
   "    sft_trainer = SFTTrainer(\n",
   "        model=model,\n",
@@ -1777,7 +1760,7 @@
   "    print(\"######### Starting training #########\")\n",
   "    sft_trainer.train()\n",
   "\n",
-  "    return model, sft_config"
+  "    return model"
   ]
  },
  {
@@ -1880,57 +1863,24 @@
   },
  {
   "cell_type": "code",
-  "execution_count": 14,
-  "metadata": {
-   "id": "pRChHaTacB4T"
-  },
+  "execution_count": null,
+  "metadata": {},
   "outputs": [],
   "source": [
-  "reward_funcs = [reward_answer_correctness]\n",
-  "def train_grpo(model, tokenizer, dataset, reward_funcs, MAX_STEP):\n",
+  "def train_grpo(model, tokenizer, dataset, reward_funcs, grpo_config):\n",
   "    \"\"\"Run GRPO training (reasoning mode enabled)\"\"\"\n",
   "\n",
-  "    # Compute token lengths (with enable_thinking=True)\n",
-  "    def get_token_length(example):\n",
-  "        return len(tokenizer.apply_chat_template(\n",
-  "            example[\"prompt\"],\n",
-  "            add_generation_prompt=True,\n",
-  "            tokenize=True,\n",
-  "        ))\n",
-  "\n",
-  "    dataset = dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n",
-  "    max_prompt_length = int(np.quantile(dataset[\"token_length\"], 0.9))\n",
-  "    dataset = dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n",
-  "\n",
-  "    # GRPO settings\n",
-  "    \"\"\"\n",
-  "    This time we just run with the defaults, without tuning optim etc.\n",
-  "    \"\"\"\n",
-  "    training_args = GRPOConfig(\n",
-  "        temperature=0.6,  # Recommended value for Qwen3\n",
-  "        learning_rate=5e-6,  # Slightly high because this is LoRA.\n",
-  "        per_device_train_batch_size=1,\n",
-  "        gradient_accumulation_steps=8,\n",
-  "        num_generations=4,  # Number of samples generated per prompt. 4 or more is generally recommended; setting this to 2 is roughly equivalent to online DPO (though the loss is computed differently).\n",
-  "        max_prompt_length=max_prompt_length,\n",
-  "        max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n",
-  "        max_steps=MAX_STEP,\n",
-  "        save_steps=100,\n",
-  "        output_dir=\"outputs\",\n",
-  "        report_to=\"wandb\"\n",
-  "    )\n",
-  "\n",
   "    trainer = GRPOTrainer(\n",
   "        model=model,\n",
   "        processing_class=tokenizer,\n",
   "        reward_funcs=reward_funcs,\n",
-  "        args=training_args,\n",
+  "        args=grpo_config,\n",
   "        train_dataset=dataset,\n",
   "    )\n",
   "\n",
   "    print(\"######### Starting training #########\")\n",
   "    trainer.train()\n",
-  "    return model, training_args"
+  "    return model"
   ]
  },
  {
@@ -3321,6 +3271,31 @@
   "sft_dataset = prepare_sft_dataset()"
   ]
  },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "\n",
+  "# SFT settings\n",
+  "sft_config = SFTConfig(\n",
+  "    dataset_text_field=\"text\",\n",
+  "    per_device_train_batch_size=2,\n",
+  "    gradient_accumulation_steps=2,\n",
+  "    warmup_ratio=0.1,\n",
+  "    num_train_epochs=2,  # Fairly standard for LLM SFT.\n",
+  "    learning_rate=2e-5,\n",
+  "    logging_steps=5,\n",
+  "    optim=\"adamw_torch\",\n",
+  "    weight_decay=0.01,\n",
+  "    lr_scheduler_type=\"linear\",\n",
+  "    seed=3407,\n",
+  "    output_dir=\"outputs_sft\",\n",
+  "    report_to=\"wandb\"\n",
+  ")\n"
+  ]
+ },
  {
   "cell_type": "code",
   "execution_count": 17,
@@ -3833,7 +3808,7 @@
   "source": [
   "print(\"\\n=== 2. SFT (supervised fine-tuning) phase ===\")\n",
   "print(\"First, teach the model the correct answer format...\")\n",
-  "model, sft_config = train_sft(model, tokenizer, sft_dataset)"
+  "model = train_sft(model, tokenizer, sft_dataset, sft_config)"
   ]
  },
  {
@@ -3921,145 +3896,63 @@
   ]
  },
  {
-  "cell_type": "code",
-  "execution_count": 20,
+  "cell_type": "markdown",
   "metadata": {
-   "colab": {
-    "base_uri": "https://localhost:8080/",
-    "height": 197,
-    "referenced_widgets": [ ... Colab widget IDs omitted ... ]
-   },
-   "id": "qdVm9pTbcB0U",
-   "outputId": "7690ce58-a158-4434-9a19-89ce72432844"
+   "id": "ABQR0hjAGiA0"
   },
-  "outputs": [ ... saved cell outputs omitted (stdout banner plus dataset download, split generation, and Map progress widgets) ... ],
   "source": [
+  "### Run GRPO training.\n",
+  "\n",
+  "max_steps is set to 100 here, so 100 steps (with this config the effective batch size is 8, which amounts to training on 800 examples from the dataset)."
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
+  "source": [
+  "\n",
   "print(\"\\n=== 3. GRPO (reinforcement learning) phase ===\")\n",
   "print(\"Preparing the GRPO dataset...\")\n",
   "# Add a placeholder SYSTEM_PROMPT as it's a required argument\n",
-  "grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # fix"
+  "grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # fix\n",
+  "\n",
+  "# Compute token lengths (with enable_thinking=True)\n",
+  "def get_token_length(example):\n",
+  "    return len(tokenizer.apply_chat_template(\n",
+  "        example[\"prompt\"],\n",
+  "        add_generation_prompt=True,\n",
+  "        tokenize=True,\n",
+  "    ))\n",
+  "\n",
+  "grpo_dataset = grpo_dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n",
+  "max_prompt_length = int(np.quantile(grpo_dataset[\"token_length\"], 0.9))\n",
+  "grpo_dataset = grpo_dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n",
+  "\n",
+  "MAX_STEP = 100\n",
+  "grpo_config = GRPOConfig(\n",
+  "    temperature=0.6,  # Recommended value for Qwen3\n",
+  "    learning_rate=5e-6,  # Slightly high because this is LoRA.\n",
+  "    per_device_train_batch_size=1,\n",
+  "    gradient_accumulation_steps=8,\n",
+  "    num_generations=4,  # Number of samples generated per prompt. 4 or more is generally recommended; setting this to 2 is roughly equivalent to online DPO (though the loss is computed differently).\n",
+  "    max_prompt_length=max_prompt_length,\n",
+  "    max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n",
+  "    max_steps=MAX_STEP,\n",
+  "    save_steps=100,\n",
+  "    output_dir=\"outputs\",\n",
+  "    report_to=\"wandb\"\n",
+  ")"
   ]
  },
  {
-  "cell_type": "markdown",
-  "metadata": {
-   "id": "ABQR0hjAGiA0"
-  },
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {},
+  "outputs": [],
   "source": [
-  "### Run GRPO training.\n",
-  "\n",
-  "max_steps is set to 100 here, so 100 steps (with this config the effective batch size is 8, which amounts to training on 800 examples from the dataset)."
+  "reward_funcs = [reward_answer_correctness]"
   ]
  },
  {
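The "800 examples" estimate in the new markdown cell follows directly from the config values added in the same hunk; restated as a worked calculation (counted in optimizer steps — how the trainer splits these across the num_generations completions per prompt is a TRL implementation detail):

    effective batch = per_device_train_batch_size × gradient_accumulation_steps = 1 × 8 = 8
    examples seen  ≈ max_steps × effective batch = 100 × 8 = 800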
@@ -4211,7 +4104,8 @@
   ],
   "source": [
   "print(\"3. Starting GRPO training...\")\n",
-  "model, grpo_config = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, 100)"
+  "\n",
+  "model = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, grpo_config)"
   ]
  },
  {
@@ -4264,15 +4158,6 @@
   ")"
   ]
  },
- {
-  "cell_type": "code",
-  "execution_count": 24,
-  "metadata": {
-   "id": "4uKT5NtLaM9Y"
-  },
-  "outputs": [],
-  "source": []
- },
  {
   "cell_type": "code",
   "execution_count": 25,
@@ -4371,7 +4256,9 @@
   }
  ],
  "source": [
-  "model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)"
+  "model = model.merge_and_unload()\n",
+  "model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)\n",
+  "tokenizer.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)"
   ]
  },
  {
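Assuming the model here is a PEFT-wrapped LoRA model as in the rest of the notebook, merge_and_unload() folds the LoRA adapters into the base weights, so the pushed revision is a standalone checkpoint rather than adapter-only. For a standard LoRA layer with rank r and scaling α, the merged weight is

    W_merged = W + (α / r) · B A

which is what lets the evaluation stage later load the repo without the PEFT wrapper.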
@@ -4381,7 +4268,8 @@
   },
   "source": [
   "# Run the evaluation and upload to wandb\n",
-  "- GPU memory may run out here. If that happens, restart the session and redo from this point."
+  "- GPU memory may run out here. If that happens, restart the session and redo from this point.\n",
+  "- If the session was lost but the model upload to HF had already finished, you can re-run everything below by re-running the install cell, the HF and WandB login cells, and the cells that define sft_config and grpo_config."
   ]
  },
  {
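A minimal sketch of the evaluation-only restart this hunk describes, assuming the merged model and tokenizer were pushed under LEARNED_MODEL_NAME at revision RUN_NAME as in the previous hunk (plain transformers shown; the notebook's evaluation code may load the checkpoint its own way):

    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Values as defined in the notebook's setup cells (placeholders here).
    LEARNED_MODEL_NAME = "your-org/your-model"
    RUN_NAME = "your-run-revision"

    # Reload the merged checkpoint by revision instead of re-running SFT/GRPO.
    model = AutoModelForCausalLM.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)
    tokenizer = AutoTokenizer.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)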
@@ -4774,7 +4662,7 @@
   "    eval_config = yaml.safe_load(f)\n",
   "\n",
   "eval_config[\"model_parameters\"][\"model_name\"] = LEARNED_MODEL_NAME\n",
-  "eval_config[\"model_parameters\"][\"run_name\"] = RUN_NAME\n",
+  "eval_config[\"model_parameters\"][\"revision\"] = RUN_NAME\n",
   "with open(f\"{eval_code_path}/eval_config.yaml\", \"w\") as f:\n",
   "    yaml.dump(eval_config, f)\n"
