Skip to content

Commit a39b68f

Browse files
authored
Update SMP v2 notebooks to use latest PyTorch 2.4.1, TSM 2.6.0 release (#4770)
* Update shared scripts to PT-2.4-TSM-2.6 Add cp comm type changes. * Update SMP v2 notebooks to use the latest PyTorch 2.4.1, TSM2.6.0 release * Add SMP v2 notebook for Llama3.1 finetuning and training with FP8 on P5. * Lint and format notebooks with `black` * Update Llama-v2 notebooks to explicitly support Llama-v3 models.
1 parent 8c23941 commit a39b68f

14 files changed

+2238
-202
lines changed

build_and_train_models/sm-distributed_model_parallel_v2/gpt-neox/sm-fsdp-tp_finetuning_gpt-neox.ipynb

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
"metadata": {},
8181
"outputs": [],
8282
"source": [
83-
"%pip install --upgrade \"sagemaker>=2.224\"\n",
83+
"%pip install --upgrade \"sagemaker>=2.233\"\n",
8484
"%pip install sagemaker-experiments"
8585
]
8686
},
@@ -711,9 +711,9 @@
711711
"outputs": [],
712712
"source": [
713713
"if use_fsx:\n",
714-
" hyperparameters[\n",
715-
" \"hf_pretrained_model_name_or_dir\"\n",
716-
" ] = PRETRAINED_MODEL # f\"{SM_TRAIN_DIR}{PRETRAINED_DIR}\"\n",
714+
" hyperparameters[\"hf_pretrained_model_name_or_dir\"] = (\n",
715+
" PRETRAINED_MODEL # f\"{SM_TRAIN_DIR}{PRETRAINED_DIR}\"\n",
716+
" )\n",
717717
"else:\n",
718718
" hyperparameters[\"hf_pretrained_model_name_or_dir\"] = PRETRAINED_MODEL"
719719
]

build_and_train_models/sm-distributed_model_parallel_v2/gpt-neox/sm-fsdp-tp_train_gpt-neox.ipynb

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@
7474
"metadata": {},
7575
"outputs": [],
7676
"source": [
77-
"%pip install --upgrade \"sagemaker>=2.224\"\n",
77+
"%pip install --upgrade \"sagemaker>=2.233\"\n",
7878
"%pip install sagemaker-experiments"
7979
]
8080
},
@@ -675,12 +675,12 @@
675675
" # If you want to resume training, set checkpoint_dir to the same path as a previous job.\n",
676676
" SM_TRAIN_DIR = \"/opt/ml/input/data/train\"\n",
677677
" hyperparameters[\"checkpoint_dir\"] = f\"{SM_TRAIN_DIR}/smp-v2/{model_type}/checkpointdir\"\n",
678-
" hyperparameters[\n",
679-
" \"training_dir\"\n",
680-
" ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n",
681-
" hyperparameters[\n",
682-
" \"test_dir\"\n",
683-
" ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n",
678+
" hyperparameters[\"training_dir\"] = (\n",
679+
" f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n",
680+
" )\n",
681+
" hyperparameters[\"test_dir\"] = (\n",
682+
" f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n",
683+
" )\n",
684684
"\n",
685685
"# The checkpoint path (hyperparameters['checkpoint_dir'] or checkpoint_s3_uri) is not unique per job.\n",
686686
"# You need to modify as needed for different runs.\n",
@@ -874,7 +874,7 @@
874874
" },\n",
875875
" },\n",
876876
" py_version=\"py311\",\n",
877-
" framework_version=\"2.3.1\",\n",
877+
" framework_version=\"2.4.1\",\n",
878878
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
879879
" output_path=s3_output_bucket,\n",
880880
" max_run=86400,\n",
@@ -956,7 +956,7 @@
956956
" },\n",
957957
" },\n",
958958
" py_version=\"py311\",\n",
959-
" framework_version=\"2.3.1\",\n",
959+
" framework_version=\"2.4.1\",\n",
960960
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
961961
" output_path=s3_output_bucket,\n",
962962
" max_run=86400,\n",

build_and_train_models/sm-distributed_model_parallel_v2/llama_v2/sm-fsdp-tp-cp_train_llama_v2.ipynb renamed to build_and_train_models/sm-distributed_model_parallel_v2/llama_v2_v3/sm-fsdp-tp-cp_train_llama_v2_v3.ipynb

Lines changed: 50 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,16 +6,16 @@
66
"tags": []
77
},
88
"source": [
9-
"# Enable Long Context Length Llama-v2 (or GPT-NeoX) training with Context Parallelism.\n",
9+
"# Enable Long Context Length Llama-v2/v3 (or GPT-NeoX) training with Context Parallelism.\n",
1010
"---\n",
1111
"\n",
1212
"This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n",
1313
"\n",
14-
"![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
14+
"![This us-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
1515
"\n",
1616
"---\n",
1717
"\n",
18-
"In this notebook, you will learn how to enable long context length distributed training of the Hugging Face Transformers Llama-v2 and GPT-NeoX models.\n",
18+
"In this notebook, you will learn how to enable long context length distributed training of the Hugging Face Transformers Llama-v2/v3 and GPT-NeoX models.\n",
1919
"\n",
2020
"You can either launch this notebook from an Amazon SageMaker notebook instance which handles all credentials automatically,\n",
2121
"or by running it locally and setting credentials manually.\n",
@@ -74,7 +74,7 @@
7474
"metadata": {},
7575
"outputs": [],
7676
"source": [
77-
"%pip install --upgrade \"sagemaker>=2.224\"\n",
77+
"%pip install --upgrade \"sagemaker>=2.233\"\n",
7878
"%pip install sagemaker-experiments"
7979
]
8080
},
@@ -187,7 +187,7 @@
187187
"source": [
188188
"### Choose Model\n",
189189
"\n",
190-
"Choose to train either the `GPT-NeoX` or `Llama-v2` model."
190+
"Choose to train either the `GPT-NeoX`, `Llama-v2`, or `Llama-v3` model."
191191
]
192192
},
193193
{
@@ -196,7 +196,7 @@
196196
"metadata": {},
197197
"outputs": [],
198198
"source": [
199-
"model_type = \"llama_v2\" # [\"gpt_neox\", \"llama_v2\"]"
199+
"model_type = \"llama_v2\" # [\"gpt_neox\", \"llama_v2\", \"llama_v3\"]"
200200
]
201201
},
202202
{
@@ -477,7 +477,7 @@
477477
"metadata": {},
478478
"outputs": [],
479479
"source": [
480-
"s3_output_bucket = f\"s3://sagemaker-{region}-{account}/smp-fsdp-tp/{model_type}-outputdir/\""
480+
"s3_output_bucket = f\"s3://sagemaker-{region}-{account}/smp-fsdp-tp/{}-outputdir/\""
481481
]
482482
},
483483
{
@@ -640,12 +640,12 @@
640640
" # If you want to resume training, set checkpoint_dir to the same path as a previous job.\n",
641641
" SM_TRAIN_DIR = \"/opt/ml/input/data/train\"\n",
642642
" hyperparameters[\"checkpoint_dir\"] = f\"{SM_TRAIN_DIR}/smp-v2/{model_type}/checkpointdir\"\n",
643-
" hyperparameters[\n",
644-
" \"training_dir\"\n",
645-
" ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n",
646-
" hyperparameters[\n",
647-
" \"test_dir\"\n",
648-
" ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n",
643+
" hyperparameters[\"training_dir\"] = (\n",
644+
" f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n",
645+
" )\n",
646+
" hyperparameters[\"test_dir\"] = (\n",
647+
" f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n",
648+
" )\n",
649649
"\n",
650650
"# The checkpoint path (hyperparameters['checkpoint_dir'] or checkpoint_s3_uri) is not unique per job.\n",
651651
"# You need to modify as needed for different runs.\n",
@@ -698,6 +698,26 @@
698698
" \"num_layers\": 80,\n",
699699
" },\n",
700700
" },\n",
701+
" \"llama_v3\": {\n",
702+
" 8: {\n",
703+
" \"hidden_width\": 4096,\n",
704+
" \"llama_intermediate_size\": 14336,\n",
705+
" \"max_context_width\": 2048,\n",
706+
" \"num_heads\": 32,\n",
707+
" \"num_layers\": 32,\n",
708+
" \"rotary_emb_base\": 500000,\n",
709+
" \"vocab_size\": 128256,\n",
710+
" },\n",
711+
" 70: {\n",
712+
" \"hidden_width\": 8192,\n",
713+
" \"llama_intermediate_size\": 28672,\n",
714+
" \"max_context_width\": 2048,\n",
715+
" \"num_heads\": 64,\n",
716+
" \"num_layers\": 80,\n",
717+
" \"rotary_emb_base\": 500000,\n",
718+
" \"vocab_size\": 128256,\n",
719+
" },\n",
720+
" },\n",
701721
"}\n",
702722
"\n",
703723
"model_params = model_configs.get(model_type, {}).get(model_size)\n",
@@ -840,7 +860,7 @@
840860
" },\n",
841861
" },\n",
842862
" py_version=\"py311\",\n",
843-
" framework_version=\"2.3.1\",\n",
863+
" framework_version=\"2.4.1\",\n",
844864
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
845865
" output_path=s3_output_bucket,\n",
846866
" max_run=86400,\n",
@@ -923,7 +943,7 @@
923943
" },\n",
924944
" },\n",
925945
" py_version=\"py311\",\n",
926-
" framework_version=\"2.3.1\",\n",
946+
" framework_version=\"2.4.1\",\n",
927947
" # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
928948
" output_path=s3_output_bucket,\n",
929949
" max_run=86400,\n",
@@ -976,35 +996,35 @@
976996
"This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
977997
"\n",
978998
"\n",
979-
"![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
999+
"![This us-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9801000
"\n",
981-
"![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1001+
"![This us-east-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-east-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9821002
"\n",
983-
"![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1003+
"![This us-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/us-west-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9841004
"\n",
985-
"![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1005+
"![This ca-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ca-central-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9861006
"\n",
987-
"![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1007+
"![This sa-east-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/sa-east-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9881008
"\n",
989-
"![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1009+
"![This eu-west-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9901010
"\n",
991-
"![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1011+
"![This eu-west-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9921012
"\n",
993-
"![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1013+
"![This eu-west-3 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-west-3/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9941014
"\n",
995-
"![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1015+
"![This eu-central-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-central-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9961016
"\n",
997-
"![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1017+
"![This eu-north-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/eu-north-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
9981018
"\n",
999-
"![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1019+
"![This ap-southeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
10001020
"\n",
1001-
"![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1021+
"![This ap-southeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-southeast-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
10021022
"\n",
1003-
"![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1023+
"![This ap-northeast-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
10041024
"\n",
1005-
"![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n",
1025+
"![This ap-northeast-2 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-northeast-2/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n",
10061026
"\n",
1007-
"![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2|sm-fsdp-tp-cp_train_llama_v2.ipynb)\n"
1027+
"![This ap-south-1 badge failed to load. Check your device's internet connectivity, otherwise the service is currently unavailable](https://prod.us-west-2.tcx-beacon.docs.aws.dev/sagemaker-nb/ap-south-1/build_and_train_models|sm-distributed_model_parallel_v2|llama_v2_v3|sm-fsdp-tp-cp_train_llama_v2_v3.ipynb)\n"
10081028
]
10091029
},
10101030
{

0 commit comments

Comments
 (0)