|
6 | 6 | "tags": []
|
7 | 7 | },
|
8 | 8 | "source": [
|
9 |
| - "# Enable Long Context Length Llama-v2 (or GPT-NeoX) training with Context Parallelism.\n", |
| 9 | + "# Enable Long Context Length Llama-v2/v3 (or GPT-NeoX) training with Context Parallelism.\n", |
10 | 10 | "---\n",
|
11 | 11 | "\n",
|
12 | 12 | "This notebook's CI test result for us-west-2 is as follows. CI test results in other regions can be found at the end of the notebook.\n",
|
13 | 13 | "\n",
|
14 |
| - "\n", |
| 14 | + "\n", |
15 | 15 | "\n",
|
16 | 16 | "---\n",
|
17 | 17 | "\n",
|
18 |
| - "In this notebook, you will learn how to enable long context length distributed training of the Hugging Face Transformers Llama-v2 and GPT-NeoX models.\n", |
| 18 | + "In this notebook, you will learn how to enable long context length distributed training of the Hugging Face Transformers Llama-v2/v3 and GPT-NeoX models.\n", |
19 | 19 | "\n",
|
20 | 20 | "You can either launch this notebook from an Amazon SageMaker notebook instance which handles all credentials automatically,\n",
|
21 | 21 | "or by running it locally and setting credentials manually.\n",
|
|
74 | 74 | "metadata": {},
|
75 | 75 | "outputs": [],
|
76 | 76 | "source": [
|
77 |
| - "%pip install --upgrade \"sagemaker>=2.224\"\n", |
| 77 | + "%pip install --upgrade \"sagemaker>=2.233\"\n", |
78 | 78 | "%pip install sagemaker-experiments"
|
79 | 79 | ]
|
80 | 80 | },
|
|
187 | 187 | "source": [
|
188 | 188 | "### Choose Model\n",
|
189 | 189 | "\n",
|
190 |
| - "Choose to train either the `GPT-NeoX` or `Llama-v2` model." |
| 190 | + "Choose to train either the `GPT-NeoX`, `Llama-v2`, or `Llama-v3` model." |
191 | 191 | ]
|
192 | 192 | },
|
193 | 193 | {
|
|
196 | 196 | "metadata": {},
|
197 | 197 | "outputs": [],
|
198 | 198 | "source": [
|
199 |
| - "model_type = \"llama_v2\" # [\"gpt_neox\", \"llama_v2\"]" |
| 199 | + "model_type = \"llama_v2\" # [\"gpt_neox\", \"llama_v2\", \"llama_v3\"]" |
200 | 200 | ]
|
201 | 201 | },
|
202 | 202 | {
|
|
477 | 477 | "metadata": {},
|
478 | 478 | "outputs": [],
|
479 | 479 | "source": [
|
480 |
| - "s3_output_bucket = f\"s3://sagemaker-{region}-{account}/smp-fsdp-tp/{model_type}-outputdir/\"" |
| 480 | + "s3_output_bucket = f\"s3://sagemaker-{region}-{account}/smp-fsdp-tp/{}-outputdir/\"" |
481 | 481 | ]
|
482 | 482 | },
|
483 | 483 | {
|
|
640 | 640 | " # If you want to resume training, set checkpoint_dir to the same path as a previous job.\n",
|
641 | 641 | " SM_TRAIN_DIR = \"/opt/ml/input/data/train\"\n",
|
642 | 642 | " hyperparameters[\"checkpoint_dir\"] = f\"{SM_TRAIN_DIR}/smp-v2/{model_type}/checkpointdir\"\n",
|
643 |
| - " hyperparameters[\n", |
644 |
| - " \"training_dir\"\n", |
645 |
| - " ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n", |
646 |
| - " hyperparameters[\n", |
647 |
| - " \"test_dir\"\n", |
648 |
| - " ] = f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n", |
| 643 | + " hyperparameters[\"training_dir\"] = (\n", |
| 644 | + " f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/train_synthetic\"\n", |
| 645 | + " )\n", |
| 646 | + " hyperparameters[\"test_dir\"] = (\n", |
| 647 | + " f\"{SM_TRAIN_DIR}/datasets/pytorch-gpt2-data/pytorch_gpt2/val_synthetic\"\n", |
| 648 | + " )\n", |
649 | 649 | "\n",
|
650 | 650 | "# The checkpoint path (hyperparameters['checkpoint_dir'] or checkpoint_s3_uri) is not unique per job.\n",
|
651 | 651 | "# You need to modify as needed for different runs.\n",
|
|
698 | 698 | " \"num_layers\": 80,\n",
|
699 | 699 | " },\n",
|
700 | 700 | " },\n",
|
| 701 | + " \"llama_v3\": {\n", |
| 702 | + " 8: {\n", |
| 703 | + " \"hidden_width\": 4096,\n", |
| 704 | + " \"llama_intermediate_size\": 14336,\n", |
| 705 | + " \"max_context_width\": 2048,\n", |
| 706 | + " \"num_heads\": 32,\n", |
| 707 | + " \"num_layers\": 32,\n", |
| 708 | + " \"rotary_emb_base\": 500000,\n", |
| 709 | + " \"vocab_size\": 128256,\n", |
| 710 | + " },\n", |
| 711 | + " 70: {\n", |
| 712 | + " \"hidden_width\": 8192,\n", |
| 713 | + " \"llama_intermediate_size\": 28672,\n", |
| 714 | + " \"max_context_width\": 2048,\n", |
| 715 | + " \"num_heads\": 64,\n", |
| 716 | + " \"num_layers\": 80,\n", |
| 717 | + " \"rotary_emb_base\": 500000,\n", |
| 718 | + " \"vocab_size\": 128256,\n", |
| 719 | + " },\n", |
| 720 | + " },\n", |
701 | 721 | "}\n",
|
702 | 722 | "\n",
|
703 | 723 | "model_params = model_configs.get(model_type, {}).get(model_size)\n",
|
|
840 | 860 | " },\n",
|
841 | 861 | " },\n",
|
842 | 862 | " py_version=\"py311\",\n",
|
843 |
| - " framework_version=\"2.3.1\",\n", |
| 863 | + " framework_version=\"2.4.1\",\n", |
844 | 864 | " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
|
845 | 865 | " output_path=s3_output_bucket,\n",
|
846 | 866 | " max_run=86400,\n",
|
|
923 | 943 | " },\n",
|
924 | 944 | " },\n",
|
925 | 945 | " py_version=\"py311\",\n",
|
926 |
| - " framework_version=\"2.3.1\",\n", |
| 946 | + " framework_version=\"2.4.1\",\n", |
927 | 947 | " # image_uri=$IMAGE, # Either provide `framework_version` or `image_uri`\n",
|
928 | 948 | " output_path=s3_output_bucket,\n",
|
929 | 949 | " max_run=86400,\n",
|
|
976 | 996 | "This notebook was tested in multiple regions. The test results are as follows, except for us-west-2 which is shown at the top of the notebook.\n",
|
977 | 997 | "\n",
|
978 | 998 | "\n",
|
979 |
| - "\n", |
| 999 | + "\n", |
980 | 1000 | "\n",
|
981 |
| - "\n", |
| 1001 | + "\n", |
982 | 1002 | "\n",
|
983 |
| - "\n", |
| 1003 | + "\n", |
984 | 1004 | "\n",
|
985 |
| - "\n", |
| 1005 | + "\n", |
986 | 1006 | "\n",
|
987 |
| - "\n", |
| 1007 | + "\n", |
988 | 1008 | "\n",
|
989 |
| - "\n", |
| 1009 | + "\n", |
990 | 1010 | "\n",
|
991 |
| - "\n", |
| 1011 | + "\n", |
992 | 1012 | "\n",
|
993 |
| - "\n", |
| 1013 | + "\n", |
994 | 1014 | "\n",
|
995 |
| - "\n", |
| 1015 | + "\n", |
996 | 1016 | "\n",
|
997 |
| - "\n", |
| 1017 | + "\n", |
998 | 1018 | "\n",
|
999 |
| - "\n", |
| 1019 | + "\n", |
1000 | 1020 | "\n",
|
1001 |
| - "\n", |
| 1021 | + "\n", |
1002 | 1022 | "\n",
|
1003 |
| - "\n", |
| 1023 | + "\n", |
1004 | 1024 | "\n",
|
1005 |
| - "\n", |
| 1025 | + "\n", |
1006 | 1026 | "\n",
|
1007 |
| - "\n" |
| 1027 | + "\n" |
1008 | 1028 | ]
|
1009 | 1029 | },
|
1010 | 1030 | {
|
|
0 commit comments