
Commit 8c23941

Update Hyper Parameters Llama 3.2 Fine-tuning (#4764)
* llama 3.2 fine-tuning
* update hyperparameters
* update hyperparameters
* update hyperparameters
* update hyperparameters
1 parent 3ab0937 commit 8c23941

File tree

2 files changed (+5, -11 lines)


generative_ai/sm-jumpstart_foundation_llama_3_2_3b_finetuning.ipynb

Lines changed: 1 addition & 1 deletion
@@ -695,7 +695,7 @@
  "- **instruction_tuned** - Whether to instruction-train the model or not. At most one of `instruction_tuned` and `chat_dataset` can be `True`. Must be `True` or `False`. Default is `False`.\n",
  "- **chat_dataset** - If `True`, dataset is assumed to be in chat format. At most one of `instruction_tuned` and `chat_dataset` can be `True`. Default is `False`.\n",
  "- **add_input_output_demarcation_key** - For an instruction tuned dataset, if this is `True`, a demarcation key (\"### Response:\\n\") is added between the prompt and completion before training. Default is `True`.\n",
- "- **per_device_train_batch_size** - The batch size per GPU core/CPU for training. Default is 1.\n",
+ "- **per_device_train_batch_size** - The batch size per GPU core/CPU for training. Default is 4.\n",
  "- **per_device_eval_batch_size** - The batch size per GPU core/CPU for evaluation. Default is 1.\n",
  "- **max_train_samples** - For debugging purposes or quicker training, truncate the number of training examples to this value. Value -1 means using all of the training samples. Must be a positive integer or -1. Default is -1.\n",
  "- **max_val_samples** - For debugging purposes or quicker training, truncate the number of validation examples to this value. Value -1 means using all of the validation samples. Must be a positive integer or -1. Default is -1.\n",

generative_ai/sm-jumpstart_foundation_llama_3_finetuning.ipynb

Lines changed: 4 additions & 10 deletions
@@ -108,7 +108,7 @@
  "\n",
  "---\n",
  "\n",
- "First we will deploy the Llama-2 model as a SageMaker endpoint. To train/deploy 8B and 70B models, please change model_id to \"meta-textgeneration-llama-3-8b\" and \"meta-textgeneration-llama-3-70b\" respectively.\n",
+ "First we will deploy the Llama-3 model as a SageMaker endpoint. To train/deploy 8B and 70B models, please change model_id to \"meta-textgeneration-llama-3-8b\" and \"meta-textgeneration-llama-3-70b\" respectively.\n",
  "\n",
  "---"
  ]
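The corrected sentence deploys Llama-3 rather than Llama-2. A minimal deployment sketch, assuming the 8B model_id named in the text (accept_eula=True acknowledges the Meta license at deploy time):

    from sagemaker.jumpstart.model import JumpStartModel

    # Swap the model_id for "meta-textgeneration-llama-3-70b" to deploy the 70B variant.
    pretrained_model = JumpStartModel(model_id="meta-textgeneration-llama-3-8b")
    pretrained_predictor = pretrained_model.deploy(accept_eula=True)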
@@ -193,9 +193,7 @@
  " },\n",
  "}\n",
  "try:\n",
- "    response = pretrained_predictor.predict(\n",
- "        payload, custom_attributes=\"accept_eula=false\"\n",
- "    )\n",
+ "    response = pretrained_predictor.predict(payload, custom_attributes=\"accept_eula=false\")\n",
  "    print_response(payload, response)\n",
  "except Exception as e:\n",
  "    print(e)"
@@ -249,9 +247,7 @@
  "dolly_dataset = load_dataset(\"databricks/databricks-dolly-15k\", split=\"train\")\n",
  "\n",
  "# To train for question answering/information extraction, you can replace the assertion in next line to example[\"category\"] == \"closed_qa\"/\"information_extraction\".\n",
- "summarization_dataset = dolly_dataset.filter(\n",
- "    lambda example: example[\"category\"] == \"summarization\"\n",
- ")\n",
+ "summarization_dataset = dolly_dataset.filter(lambda example: example[\"category\"] == \"summarization\")\n",
  "summarization_dataset = summarization_dataset.remove_columns(\"category\")\n",
  "\n",
  "# We split the dataset into two where test data is used to evaluate at the end.\n",
@@ -376,9 +372,7 @@
  "    instance_type=\"ml.g5.12xlarge\",  # For Llama-3-70b, add instance_type = \"ml.g5.48xlarge\"\n",
  ")\n",
  "# By default, instruction tuning is set to false. Thus, to use instruction tuning dataset you use\n",
- "estimator.set_hyperparameters(\n",
- "    instruction_tuned=\"True\", epoch=\"5\", max_input_length=\"1024\"\n",
- ")\n",
+ "estimator.set_hyperparameters(instruction_tuned=\"True\", epoch=\"5\", max_input_length=\"1024\")\n",
  "estimator.fit({\"training\": train_data_location})"
  ]
  },
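After set_hyperparameters, estimator.fit launches the fine-tuning job against the uploaded training channel. A sketch of the step that typically follows in these notebooks, hosting the fine-tuned weights (the default instance type and the EULA attribute value are assumptions):

    # Host the fine-tuned model behind a new endpoint once training completes.
    finetuned_predictor = estimator.deploy()  # uses the model's recommended instance type
    response = finetuned_predictor.predict(payload, custom_attributes="accept_eula=false")
    print_response(payload, response)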
