|
67 | 67 | "This guide walks you through how to fine-tune Gemma on a mobile game NPC dataset using Hugging Face [Transformers](https://huggingface.co/docs/transformers/index) and [TRL](https://huggingface.co/docs/trl/index). You will learn:\n", |
68 | 68 | "\n", |
69 | 69 | "- Set up the development environment\n",
70 | | - "- Create and prepare the fine-tuning dataset\n", |
| 70 | + "- Prepare the fine-tuning dataset\n", |
71 | 71 | "- Fine-tune the full Gemma model using TRL and the SFTTrainer\n",
72 | 72 | "- Test model inference and run vibe checks\n",
73 | 73 | "\n", |
|
256 | 256 | } |
257 | 257 | ], |
258 | 258 | "source": [ |
259 | | - "npc_type = \"martian\" #@param [\"martian\", \"venusian\"]\n", |
260 | | - "\n", |
261 | 259 | "from datasets import load_dataset\n", |
262 | 260 | "\n", |
263 | 261 | "def create_conversation(sample):\n", |
|
268 | 266 | " ]\n", |
269 | 267 | " }\n", |
270 | 268 | "\n", |
271 | | - "# Load dataset from the hub\n", |
| 269 | + "npc_type = \"martian\" #@param [\"martian\", \"venusian\"]\n", |
| 270 | + "\n", |
| 271 | + "# Load dataset from the Hub\n", |
272 | 272 | "dataset = load_dataset(\"bebechien/MobileGameNPC\", npc_type, split=\"train\")\n", |
273 | 273 | "\n", |
274 | 274 | "# Convert dataset to conversational format\n", |
275 | | - "dataset = dataset.map(create_conversation, remove_columns=dataset.features,batched=False)\n", |
276 | | - "# split dataset into 80% training samples and 20% test samples\n", |
| 275 | + "dataset = dataset.map(create_conversation, remove_columns=dataset.features, batched=False)\n", |
| 276 | + "\n", |
| 277 | + "# Split dataset into 80% training samples and 20% test samples\n", |
277 | 278 | "dataset = dataset.train_test_split(test_size=0.2, shuffle=False)\n", |
278 | 279 | "\n", |
279 | 280 | "# Print formatted user prompt\n", |
|
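> The body of `create_conversation` is elided in this hunk. For reference, a minimal sketch of the kind of mapping it performs might look like the following; the `player` and `npc` column names are assumptions standing in for the dataset's actual text fields, while the `messages` structure is the standard conversational format that TRL's `SFTTrainer` consumes directly.

```python
# Minimal sketch only -- the notebook's real create_conversation is elided above.
# "player" and "npc" are hypothetical column names, not the dataset's actual schema.
def create_conversation(sample):
    return {
        "messages": [
            {"role": "user", "content": sample["player"]},    # player's line
            {"role": "assistant", "content": sample["npc"]},  # NPC's reply
        ]
    }
```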
332 | 333 | "id": "M3w3b9-O4fDz" |
333 | 334 | }, |
334 | 335 | "source": [ |
335 | | - "## Before fine-tune (Base model)\n", |
| 336 | + "## Before fine-tuning\n",
336 | 337 | "\n", |
337 | | - "The output below shows that the model is a generalist and isn't specifically trained for your NPC's character." |
| 338 | + "The output below shows that the base model's out-of-the-box capabilities may not be good enough for this use case."
338 | 339 | ] |
339 | 340 | }, |
340 | 341 | { |
|
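> For the base-model check described above, a sketch of the kind of inference cell that typically follows might look like this. The checkpoint id is an assumption (the model actually used in the notebook is not visible in this hunk); recent versions of the `transformers` text-generation pipeline accept chat-style message lists directly.

```python
from transformers import pipeline

# Hypothetical checkpoint id -- the notebook's actual model is not shown in this hunk.
model_id = "google/gemma-3-1b-it"

generator = pipeline("text-generation", model=model_id, device_map="auto")

# Ask an in-character question and see how the untuned model responds.
messages = [{"role": "user", "content": "Hello! Who are you?"}]
output = generator(messages, max_new_tokens=128)

# With chat input, generated_text is the full conversation; the last turn is the reply.
print(output[0]["generated_text"][-1]["content"])
```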
720 | 721 | "# Access the log history\n", |
721 | 722 | "log_history = trainer.state.log_history\n", |
722 | 723 | "\n", |
723 | | - "# Extract training loss and global steps\n", |
724 | | - "train_losses = []\n", |
725 | | - "eval_losses = []\n", |
726 | | - "epoch_train = []\n", |
727 | | - "epoch_eval = []\n", |
728 | | - "\n", |
729 | | - "for log in log_history:\n", |
730 | | - " if \"loss\" in log: # Check for training loss\n", |
731 | | - " train_losses.append(log[\"loss\"])\n", |
732 | | - " epoch_train.append(log[\"epoch\"])\n", |
733 | | - " if \"eval_loss\" in log: # Check for validation loss\n", |
734 | | - " eval_losses.append(log[\"eval_loss\"])\n", |
735 | | - " epoch_eval.append(log['epoch'])\n", |
| 724 | + "# Extract training / validation loss\n", |
| 725 | + "train_losses = [log[\"loss\"] for log in log_history if \"loss\" in log]\n", |
| 726 | + "epoch_train = [log[\"epoch\"] for log in log_history if \"loss\" in log]\n", |
| 727 | + "eval_losses = [log[\"eval_loss\"] for log in log_history if \"eval_loss\" in log]\n", |
| 728 | + "epoch_eval = [log[\"epoch\"] for log in log_history if \"eval_loss\" in log]\n", |
736 | 729 | "\n", |
737 | 730 | "# Plot the training loss\n", |
738 | 731 | "plt.plot(epoch_train, train_losses, label=\"Training Loss\")\n", |
|
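> The list comprehensions above work because `trainer.state.log_history` interleaves separate dicts for training and evaluation records, distinguishable by their keys. Illustrative entries (the numeric values are made up) look like this:

```python
# Illustrative log_history entries -- the numeric values here are made up.
log_history = [
    {"loss": 2.31, "learning_rate": 2e-4, "epoch": 0.5, "step": 10},     # training record
    {"eval_loss": 2.05, "eval_runtime": 3.2, "epoch": 1.0, "step": 20},  # evaluation record
]
```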
771 | 764 | "\n", |
772 | 765 | "After the training is done, you'll want to evaluate and test your model. You can load different samples from the test dataset and evaluate the model on those samples.\n", |
773 | 766 | "\n", |
774 | | - "For this particular use case, the best model is a matter of preference. Interestingly, what we'd normally call 'overfitting' can be very useful for a game NPC. It forces the model to forget general information and instead lock onto the specific persona and characteristics it was trained on, ensuring it stays consistently in character.\n", |
775 | | - "\n", |
776 | | - "> Note: Evaluating generative AI models is not a trivial task since one input can have multiple correct outputs. This guide only focuses on manual evaluation and vibe checks." |
| 767 | + "For this particular use case, the best model is a matter of preference. Interestingly, what we'd normally call 'overfitting' can be very useful for a game NPC. It forces the model to forget general information and instead lock onto the specific persona and characteristics it was trained on, ensuring it stays consistently in character.\n" |
777 | 768 | ] |
778 | 769 | }, |
779 | 770 | { |
|
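> A sketch of the manual vibe check the paragraph above describes might look like the following. The model path is an assumption (substitute the trainer's actual `output_dir`), and it assumes each test sample's `messages` list ends with the reference assistant reply, as produced by `create_conversation`.

```python
from random import randint
from transformers import pipeline

# Hypothetical path -- substitute the trainer's actual output_dir.
generator = pipeline("text-generation", model="./gemma-npc-sft", device_map="auto")

# Pick a random held-out sample and strip off the reference reply.
sample = dataset["test"][randint(0, len(dataset["test"]) - 1)]
prompt_messages = sample["messages"][:-1]

output = generator(prompt_messages, max_new_tokens=128)

print("Model:    ", output[0]["generated_text"][-1]["content"])
print("Reference:", sample["messages"][-1]["content"])
```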