@@ -359,7 +359,7 @@ Use a specific scheduler (e.g., `gke.io/topology-aware-auto`) using `--scheduler
359359 --scheduler gke.io/topology-aware-auto
360360```
361361
362- ## 9. Sophisticated Workloads: MaxText Llama3.1-8B
362+ ## 9. Sophisticated Workloads: MaxText Llama3.1-8B (TPU v6e)
363363
364364This section describes how to deploy a more complex workload, specifically training a Llama3.1-8B model using MaxText on a TPU v6e cluster.
365365
@@ -578,7 +578,263 @@ kubectl get pods --namespace default -l jobset.sigs.k8s.io/jobset-name=maxtext-l
578578kubectl logs <POD_NAME> --namespace default
579579```
580580
581- ## 10. Cleanup
581+ ## 10. Sophisticated Workloads: MaxText Llama3.1-8B (TPU v7x)
582+
583+ This section describes how to deploy the MaxText workload specifically optimized for TPU v7x (Ironwood) hardware.
584+
585+ ### 10.1 Prepare MaxText v7x Workload Directory
586+
587+ Create a directory named `maxtext_workload_v7x` and place the following files inside it.
588+
589+ #### `cluster-toolkit/maxtext_workload_v7x/requirements.txt`
590+
591+ ```text
592+ psutil
593+ jaxtyping
594+ tiktoken
595+ sentencepiece
596+ ray
597+ fastapi
598+ uvicorn
599+ portpicker
600+ pydantic
601+ ninja
602+ Pillow
603+ gcsfs
604+ omegaconf
605+ jsonlines
606+ PyYAML
607+ safetensors
608+ tabulate
609+ tensorstore
610+ transformers
611+ datasets
612+ evaluate
613+ nltk
614+ pandas
615+ ml_collections
616+ ml_dtypes
617+ pathwaysutils
618+ orbax
619+ grain
620+ tensorflow_text
621+ tensorflow_datasets
622+ tqdm
623+ ```
624+
625+ #### `cluster-toolkit/maxtext_workload_v7x/Dockerfile`
626+
627+ ```dockerfile
628+ # Use the recommended base image for TPU7x
629+ FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.2-rev1
630+
631+ # Set the working directory
632+ WORKDIR /deps
633+
634+ # Install system dependencies
635+ RUN apt-get update && apt-get install -y dnsutils
636+
637+ # Install MaxText dependencies
638+ RUN pip install google-cloud-monitoring
639+
640+ # Install Python requirements
641+ COPY requirements.txt .
642+ RUN pip install -r requirements.txt
643+
644+ # Clone MaxText
645+ RUN git clone https://github.com/AI-Hypercomputer/maxtext.git /app \
646+ && cd /app \
647+ && git checkout maxtext-tutorial-v1.0.0
648+
649+ # Set working directory to MaxText root for the runner
650+ WORKDIR /app
651+
652+ # Copy the wrapper script
653+ COPY run_maxtext.sh /app/run_maxtext.sh
654+ RUN chmod +x /app/run_maxtext.sh
655+
656+ # Entrypoint is left to default or overridden by gcluster
657+ ```
658+
659+ #### `cluster-toolkit/maxtext_workload_v7x/run_maxtext.sh`
660+
661+ ```bash
662+ #!/bin/bash
663+
664+ # Exit on error
665+ set -e
666+
667+ echo "Starting MaxText Workload..."
668+
669+ # 1. Set environment variables
670+ export ENABLE_PATHWAYS_PERSISTENCE='1'
671+ # Combine all the required XLA flags into LIBTPU_INIT_ARGS
672+ export LIBTPU_INIT_ARGS=" --xla_tpu_scoped_vmem_limit_kib=61440 --xla_tpu_bf16_emission_mode=NATIVE_EMISSION --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true --xla_tpu_use_single_sparse_core_for_all_gather_offload=true "
673+
674+ export JAX_PLATFORMS="tpu,cpu"
675+ export ENABLE_PJRT_COMPATIBILITY=true
676+ export PYTHONPATH=$PYTHONPATH:$(pwd)/src
677+ export JAX_TRACEBACK_FILTERING=off
678+
679+ # 2. Extract arguments
680+ OUTPUT_DIR=${1}
681+ if [ -z "$OUTPUT_DIR" ]; then
682+ echo "Error: Output directory argument missing."
683+ echo "Usage: $0 <output_gcs_bucket>"
684+ exit 1
685+ fi
686+
687+ MODEL_NAME=${2:-"llama3.1-8b"}
688+
689+ echo "LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS"
690+ echo "OUTPUT_DIR=$OUTPUT_DIR"
691+ echo "MODEL_NAME=$MODEL_NAME"
692+
693+ # 3. Run training
694+ python3 src/MaxText/train.py src/MaxText/configs/base.yml \
695+ model_name=$MODEL_NAME \
696+ skip_jax_distributed_system=False \
697+ dtype=bfloat16 \
698+ per_device_batch_size=1 \
699+ ici_fsdp_parallelism=64 \
700+ max_target_length=4096 \
701+ profiler=xplane \
702+ profile_periodically_period=10000 \
703+ async_checkpointing=False \
704+ enable_checkpointing=False \
705+ use_iota_embed=True \
706+ remat_policy=custom \
707+ decoder_layer_input=offload \
708+ query_proj=offload \
709+ key_proj=offload \
710+ value_proj=offload \
711+ out_proj=offload \
712+ dataset_type=synthetic \
713+ opt_type=adamw \
714+ mu_dtype=bfloat16 \
715+ tokenizer_type=tiktoken \
716+ tokenizer_path=assets/tokenizer_llama3.tiktoken \
717+ sa_use_fused_bwd_kernel=True \
718+ attention=flash \
719+ steps=30 \
720+ base_output_directory=$OUTPUT_DIR \
721+ use_vertex_tensorboard=false
722+ ```
723+
724+ #### `cluster-toolkit/maxtext_workload_v7x/build.sh`
725+
726+ ```bash
727+ #!/bin/bash
728+
729+ # Get current project
730+ PROJECT=$(gcloud config get-value project)
731+
732+ if [ -z "$PROJECT" ]; then
733+ echo "Error: Could not determine GCP project. Please run 'gcloud config set project <PROJECT_ID>'"
734+ exit 1
735+ fi
736+
737+ IMAGE_NAME=gcr.io/$PROJECT/maxtext-runner:latest
738+
739+ echo "Building image $IMAGE_NAME using Cloud Build..."
740+ gcloud builds submit --tag $IMAGE_NAME .
741+
742+ echo "Image built successfully!"
743+ echo "You can now submit the job with:"
744+ echo " gcluster job submit --image $IMAGE_NAME --command 'bash run_maxtext.sh <OUTPUT_DIR>'"
745+ ```
746+
747+ #### `cluster-toolkit/maxtext_workload_v7x/submit.sh`
748+
749+ ```bash
750+ #!/bin/bash
751+
752+ # Configuration - UPDATE THESE
753+ CLUSTER_NAME="tpu7xpkv"
754+ ZONE="us-central1-c"
755+ OUTPUT_DIR="gs://gke-aishared-gsc-dev/maxtext_output_7x"
756+
757+ # Look up project
758+ PROJECT=$(gcloud config get-value project)
759+
760+ if [ -z "$PROJECT" ]; then
761+ echo "Error: Could not determine GCP project. Please run 'gcloud config set project <PROJECT_ID>'"
762+ exit 1
763+ fi
764+
765+ IMAGE_NAME=gcr.io/$PROJECT/maxtext-runner:latest
766+
767+ echo "Ensuring permissions for tpu7xpkv-gke-wl-sa..."
768+ # Note: Ensure the SA matches the one created by the 7x blueprint (tpu7xpkv-gke-wl-sa)
769+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/logging.logWriter" --quiet
770+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.admin" --quiet
771+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/monitoring.metricWriter" --quiet
772+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/logging.viewer" --quiet
773+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.objectViewer" --quiet
774+
775+ echo "Ensuring permissions for node pool service account..."
776+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-np-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/artifactregistry.reader" --quiet
777+ gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-np-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.objectViewer" --quiet
778+
779+ echo "Submitting MaxText job to cluster $CLUSTER_NAME..."
780+
781+ # Navigate to cluster-toolkit root if running from maxtext_workload_v7x
782+ if [ -f "../gcluster" ]; then
783+ GCLUSTER="../gcluster"
784+ else
785+ GCLUSTER="./gcluster"
786+ fi
787+
788+ $GCLUSTER job submit \
789+ --name maxtext-llama3-1-final-tpu7x-32 \
790+ --cluster $CLUSTER_NAME \
791+ --cluster-region us-central1 \
792+ --image $IMAGE_NAME \
793+ --command "cd /app && sed -i 's/use_vertex_tensorboard=false/use_vertex_tensorboard=false run_name=llama3-1-7x-test1/g' run_maxtext.sh && bash run_maxtext.sh $OUTPUT_DIR" \
794+ --accelerator tpu7x-32 \
795+ --nodes 1 \
796+ --vms-per-slice 8 \
797+ --topology 2x4x4 \
798+ --priority medium \
799+ --service-account workload-identity-k8s-sa
800+ ```
801+
802+ ### 10.2 Build and Submit
803+
804+ ```bash
805+ cd maxtext_workload_v7x
806+ ./build.sh
807+ ./submit.sh
808+ ```
809+
810+ ### 10.3 Verify Job and Logs
811+
812+ TPU v7x utilizes Megacore, which initializes 2 logical devices per chip. For a 32-chip slice (2x4x4 topology), you should see 64 logical devices in the logs.
813+
814+ **Using `gcluster`**:
815+
816+ ```bash
817+ # List jobs
818+ ./gcluster job list --project <YOUR_PROJECT_ID> --cluster tpu7xpkv --cluster-region us-central1
819+
820+ # View logs
821+ ./gcluster job logs maxtext-llama3-1-final-tpu7x-32 --project <YOUR_PROJECT_ID> --cluster tpu7xpkv --cluster-region us-central1
822+ ```
823+
824+ **Verification Highlights**:
825+ In the logs, look for successful JAX initialization and step processing:
826+
827+ ```text
828+ System Information: Jax Version: 0.8.2
829+ System Information: Jax Backend: PJRT C API
830+ TFRT TPU7x
831+ Num_devices: 64, shape (1, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1)
832+ ...
833+ completed step: 4, seconds: 1.182, TFLOP/s/device: 167.154, Tokens/s/device: 3464.384, loss: 10.187
834+ completed step: 5, seconds: 1.186, TFLOP/s/device: 166.597, Tokens/s/device: 3452.840, loss: 9.184
835+ ```
836+
837+ ## 11. Cleanup
582838
583839To avoid incurring unnecessary costs, destroy the deployed GKE cluster and its resources:
584840
0 commit comments