288 changes: 88 additions & 200 deletions rev_simple_unsloth_sft+grpo.ipynb
@@ -1746,26 +1746,9 @@
},
"outputs": [],
"source": [
"def train_sft(model, tokenizer, dataset):\n",
"def train_sft(model, tokenizer, sft_dataset, sft_config):\n",
" \"\"\"SFTでフォーマットを学習\"\"\"\n",
"\n",
" # SFTの設定\n",
" sft_config = SFTConfig(\n",
" dataset_text_field=\"text\",\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=2,\n",
" warmup_ratio=0.1,\n",
" num_train_epochs=2, # LLMのSFTだと一般的ですかね.\n",
" learning_rate=2e-5,\n",
" logging_steps=5,\n",
" optim=\"adamw_torch\",\n",
" weight_decay=0.01,\n",
" lr_scheduler_type=\"linear\",\n",
" seed=3407,\n",
" output_dir=\"outputs_sft\",\n",
" report_to=\"wandb\"\n",
" )\n",
"\n",
" # SFTトレーナーの作成と実行\n",
" sft_trainer = SFTTrainer(\n",
" model=model,\n",
@@ -1777,7 +1760,7 @@
" print(\"#########学習開始#########\")\n",
" sft_trainer.train()\n",
"\n",
" return model, sft_config"
" return model "
]
},
{
@@ -1880,57 +1863,24 @@
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"id": "pRChHaTacB4T"
},
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"reward_funcs = [reward_answer_correctness]\n",
"def train_grpo(model, tokenizer, dataset, reward_funcs, MAX_STEP):\n",
"def train_grpo(model, tokenizer, dataset, reward_funcs, grpo_config):\n",
" \"\"\"GRPOトレーニングを実行(reasoning mode有効)\"\"\"\n",
"\n",
" # トークン長を計算(enable_thinking=Trueで)\n",
" def get_token_length(example):\n",
" return len(tokenizer.apply_chat_template(\n",
" example[\"prompt\"],\n",
" add_generation_prompt=True,\n",
" tokenize=True,\n",
" ))\n",
"\n",
" dataset = dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n",
" max_prompt_length = int(np.quantile(dataset[\"token_length\"], 0.9))\n",
" dataset = dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n",
"\n",
" # GRPO設定\n",
" \"\"\"\n",
" 今回は特にoptimなども設定せずデフォルトで動かします.\n",
" \"\"\"\n",
" training_args = GRPOConfig(\n",
" temperature=0.6, # Qwen3の推奨値\n",
" learning_rate=5e-6, #LoRAなんでやや高めで.\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=8,\n",
" num_generations=4, # 同時に生成するサンプル数です.基本4以上が推奨です.ここを2にするとオンラインDPOを同じようなことをやっていることになります(lossの仕組みとか違いますが)\n",
" max_prompt_length=max_prompt_length,\n",
" max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n",
" max_steps = MAX_STEP,\n",
" save_steps=100,\n",
" output_dir=\"outputs\",\n",
" report_to=\"wandb\"\n",
" )\n",
"\n",
" trainer = GRPOTrainer(\n",
" model=model,\n",
" processing_class=tokenizer,\n",
" reward_funcs=reward_funcs,\n",
" args=training_args,\n",
" args=grpo_config,\n",
" train_dataset=dataset,\n",
" )\n",
"\n",
" print(\"#########学習開始#########\")\n",
" trainer.train()\n",
" return model, training_args"
" return model"
]
},
{
@@ -3321,6 +3271,31 @@
"sft_dataset = prepare_sft_dataset()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"# SFTの設定\n",
"sft_config = SFTConfig(\n",
" dataset_text_field=\"text\",\n",
" per_device_train_batch_size=2,\n",
" gradient_accumulation_steps=2,\n",
" warmup_ratio=0.1,\n",
" num_train_epochs=2, # LLMのSFTだと一般的ですかね.\n",
" learning_rate=2e-5,\n",
" logging_steps=5,\n",
" optim=\"adamw_torch\",\n",
" weight_decay=0.01,\n",
" lr_scheduler_type=\"linear\",\n",
" seed=3407,\n",
" output_dir=\"outputs_sft\",\n",
" report_to=\"wandb\"\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 17,
@@ -3833,7 +3808,7 @@
"source": [
"print(\"\\n=== 2. SFT(教師あり学習)フェーズ ===\")\n",
"print(\"まず、正しい回答フォーマットを学習させます...\")\n",
"model, sft_config = train_sft(model, tokenizer, sft_dataset)"
"model = train_sft(model, tokenizer, sft_dataset, sft_config)"
]
},
{
@@ -3921,145 +3896,63 @@
]
},
{
"cell_type": "code",
"execution_count": 20,
"cell_type": "markdown",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 197,
"referenced_widgets": [
"a60127b1fcb7473d958799b0f7313d8d",
"aed570b0ec634435bf5e21bcdaddb7be",
"9956cdb890cf45769b30cdf48a8605f5",
"8d047c122f344ae3b3108c21ecd67473",
"8e5ad14df9044026a2cdaf4caaac8af0",
"529a4caafe804eb7b4bc39e24c20cae3",
"4648b1ce4c3a44888d29b60a224d93f4",
"d6a2b375489943f38514123449c3f042",
"f1631034ee1d417c9cdaf45da2f3fb97",
"8d2fd0b48d5943a380dc13bfe33dfb4c",
"bb8f6f29f2bb49628fb43fea3304397f",
"3edc302aa9ea40b7a8e659fcb91df66f",
"3c6f7e76bba34c58ad83aed8a9d7af7f",
"8bd87551cab74be3b1126f936a9dc183",
"95c4720a8fc0452cbb58fe2b94173877",
"10cab1379dc449ad8ac3779fa8a2f111",
"6884e56c7711429aae81cbb460046d3e",
"501e94f452a041cf87f9fd6c13a57f81",
"1a63c469e15843afa938f2a2efd14250",
"7d42059ea7664db5b6c992468084dfd5",
"37f5abce183d485c90eb9271885770e5",
"762b47fedb4947eb8d5cef343ecb6269",
"cbe0d1f73e2844719fd09671e0e95f75",
"7d2f7d0e72d947da9a354a025c21880f",
"aed3745450714c518a1e64892ac5569e",
"f72202b8bf8d444db31dac051b724472",
"20afc7a33eba42aabc6e4af0b2283f90",
"aa3618c5a4ce4e83840bad915008f765",
"bbc85c1b87e74a0695177115b22b6eda",
"0451c5d59dd047e4b11372c986d8e6dd",
"dbb3b5fc644544ce8c33fcbbc931d376",
"215f10a2764b418eb11c76d147ae283e",
"2c9f4aa9f9b2439aa9001c3788c6ad54",
"c0a5edb95f9647fbb71b4f6150da5e04",
"43c4759361ab42458fa9485a5499702e",
"cd689bf6174349bdbf2bc59db8b6cfe0",
"2a7cd69621ee4b559274d9b68485c326",
"1bf9394c4dc7489fa7a578c86efd902b",
"36a29633fd0b4c679011c5f4eeb6c46d",
"72b61da2d2b7440b9704f8692b9c27c5",
"4d8c3207481247cb86c9706ce57f6f73",
"a764fa30ee6b4032a731e7bcfb0c7fe8",
"00412d4ee75540b9a543d9ead7b3526d",
"adcde5a3d87643a3b32e376052c84035"
]
},
"id": "qdVm9pTbcB0U",
"outputId": "7690ce58-a158-4434-9a19-89ce72432844"
"id": "ABQR0hjAGiA0"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"=== 3. GRPO(強化学習)フェーズ ===\n",
"GRPOデータセット準備中...\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "a60127b1fcb7473d958799b0f7313d8d",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"README.md: 0%| | 0.00/603 [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "3edc302aa9ea40b7a8e659fcb91df66f",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"data/cot-00000-of-00001.parquet: 0%| | 0.00/106M [00:00<?, ?B/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "cbe0d1f73e2844719fd09671e0e95f75",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Generating cot split: 0%| | 0/19252 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "c0a5edb95f9647fbb71b4f6150da5e04",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"Map: 0%| | 0/19252 [00:00<?, ? examples/s]"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"### GRPO学習します.\n",
"\n",
"今回はmaxstepを100にしているので,100step(設定上実質バッチサイズは8なので,データセットのうち800件で学習したことになります.)"
]
},
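A minimal sketch of the arithmetic behind the note above, assuming a single-GPU run (this sketch is editorial and not part of the notebook):

# Effective batch size and number of samples seen, using the values from grpo_config.
per_device_train_batch_size = 1
gradient_accumulation_steps = 8
max_steps = 100
num_generations = 4

effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps  # 8
samples_seen = effective_batch_size * max_steps                                   # 800
# If the trainer counts the batch in generated completions (an assumption that may
# depend on the TRL version), those 800 samples correspond to
# samples_seen // num_generations == 200 unique prompts.
print(effective_batch_size, samples_seen)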
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"\n",
"print(\"\\n=== 3. GRPO(強化学習)フェーズ ===\")\n",
"print(\"GRPOデータセット準備中...\")\n",
"# Add a placeholder SYSTEM_PROMPT as it's a required argument\n",
"grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # 修正"
"grpo_dataset = prepare_grpo_dataset(SYSTEM_PROMPT=\"\") # 修正\n",
"\n",
"# トークン長を計算(enable_thinking=Trueで)\n",
"def get_token_length(example):\n",
" return len(tokenizer.apply_chat_template(\n",
" example[\"prompt\"],\n",
" add_generation_prompt=True,\n",
" tokenize=True,\n",
" ))\n",
"\n",
"grpo_dataset = grpo_dataset.map(lambda x: {\"token_length\": get_token_length(x)})\n",
"max_prompt_length = int(np.quantile(grpo_dataset[\"token_length\"], 0.9))\n",
"grpo_dataset = grpo_dataset.filter(lambda x: x[\"token_length\"] <= max_prompt_length)\n",
"\n",
"MAX_STEP = 100\n",
"grpo_config = GRPOConfig(\n",
" temperature=0.6, # Qwen3の推奨値\n",
" learning_rate=5e-6, #LoRAなんでやや高めで.\n",
" per_device_train_batch_size=1,\n",
" gradient_accumulation_steps=8,\n",
" num_generations=4, # 同時に生成するサンプル数です.基本4以上が推奨です.ここを2にするとオンラインDPOを同じようなことをやっていることになります(lossの仕組みとか違いますが)\n",
" max_prompt_length=max_prompt_length,\n",
" max_completion_length=MAX_SEQ_LENGTH - max_prompt_length,\n",
" max_steps = MAX_STEP,\n",
" save_steps=100,\n",
" output_dir=\"outputs\",\n",
" report_to=\"wandb\"\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "ABQR0hjAGiA0"
},
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"### GRPO学習します.\n",
"\n",
"今回はmaxstepを100にしているので,100step(設定上実質バッチサイズは8なので,データセットのうち800件で学習したことになります.)"
"reward_funcs = [reward_answer_correctness]"
]
},
{
@@ -4211,7 +4104,8 @@
],
"source": [
"print(\"3. GRPO学習開始...\")\n",
"model, grpo_config = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, 100)"
"\n",
"model = train_grpo(model, tokenizer, grpo_dataset, reward_funcs, grpo_config)"
]
},
{
@@ -4264,15 +4158,6 @@
")"
]
},
{
"cell_type": "code",
"execution_count": 24,
"metadata": {
"id": "4uKT5NtLaM9Y"
},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 25,
@@ -4371,7 +4256,9 @@
}
],
"source": [
"model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)"
"model = model.merge_and_unload()\n",
"model.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)\n",
"tokenizer.push_to_hub(LEARNED_MODEL_NAME, revision=RUN_NAME)"
]
},
{
@@ -4381,7 +4268,8 @@
},
"source": [
"# 評価の実行とwandbへのアップロード\n",
"- GPUメモリが足りない場合があります。その場合はセッション再起動をして、ここからやり直してください。"
"- GPUメモリが足りない場合があります。その場合はセッション再起動をして、ここからやり直してください。\n",
"- セッションが切れてしまった場合に、HFへのモデルのアップロードが終了していれば、インストール、HFとWandBへのログイン、sft_configとgrpo_configを設定するセルを再実行することで、以下を再実行できます。"
]
},
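If the session is lost after the Hub upload above has finished, the merged model can presumably be reloaded by revision before evaluation. A minimal sketch assuming the standard transformers API; LEARNED_MODEL_NAME and RUN_NAME are the notebook's own variables, while the loading code itself is editorial and not part of the notebook:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Reload the merged model pushed above, pinned to this training run's revision.
model = AutoModelForCausalLM.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)
tokenizer = AutoTokenizer.from_pretrained(LEARNED_MODEL_NAME, revision=RUN_NAME)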
{
@@ -4774,7 +4662,7 @@
" eval_config = yaml.safe_load(f)\n",
"\n",
"eval_config[\"model_parameters\"][\"model_name\"] = LEARNED_MODEL_NAME\n",
"eval_config[\"model_parameters\"][\"run_name\"] = RUN_NAME \n",
"eval_config[\"model_parameters\"][\"revision\"] = RUN_NAME \n",
"with open(f\"{eval_code_path}/eval_config.yaml\", \"w\") as f:\n",
" yaml.dump(eval_config, f)\n"
]