Skip to content

Commit 212e5a0

Browse files
v7x testing and mem handling
1 parent 36e1398 commit 212e5a0

File tree

2 files changed

+265
-5
lines changed

2 files changed

+265
-5
lines changed

docs/replication_guide.md

Lines changed: 258 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,7 +359,7 @@ Use a specific scheduler (e.g., `gke.io/topology-aware-auto`) using `--scheduler
359359
--scheduler gke.io/topology-aware-auto
360360
```
361361
362-
## 9. Sophisticated Workloads: MaxText Llama3.1-8B
362+
## 9. Sophisticated Workloads: MaxText Llama3.1-8B (TPU v6e)
363363
364364
This section describes how to deploy a more complex workload, specifically training a Llama3.1-8B model using MaxText on a TPU v6e cluster.
365365
@@ -578,7 +578,263 @@ kubectl get pods --namespace default -l jobset.sigs.k8s.io/jobset-name=maxtext-l
578578
kubectl logs <POD_NAME> --namespace default
579579
```
580580
581-
## 10. Cleanup
581+
## 10. Sophisticated Workloads: MaxText Llama3.1-8B (TPU v7x)
582+
583+
This section describes how to deploy a MaxText workload optimized specifically for TPU v7x (Ironwood) hardware.
584+
585+
### 10.1 Prepare MaxText v7x Workload Directory
586+
587+
Create a directory named `maxtext_workload_v7x` and place the following files inside it.
588+
589+
#### `cluster-toolkit/maxtext_workload_v7x/requirements.txt`
590+
591+
```text
592+
psutil
593+
jaxtyping
594+
tiktoken
595+
sentencepiece
596+
ray
597+
fastapi
598+
uvicorn
599+
portpicker
600+
pydantic
601+
ninja
602+
Pillow
603+
gcsfs
604+
omegaconf
605+
jsonlines
606+
PyYAML
607+
safetensors
608+
tabulate
609+
tensorstore
610+
transformers
611+
datasets
612+
evaluate
613+
nltk
614+
pandas
615+
ml_collections
616+
ml_dtypes
617+
pathwaysutils
618+
orbax
619+
grain
620+
tensorflow_text
621+
tensorflow_datasets
622+
tqdm
623+
```
624+
625+
#### `cluster-toolkit/maxtext_workload_v7x/Dockerfile`
626+
627+
```dockerfile
628+
# Use the recommended base image for TPU v7x (Ironwood)
629+
FROM us-docker.pkg.dev/cloud-tpu-images/jax-ai-image/tpu:jax0.8.2-rev1
630+
631+
# Set the working directory
632+
WORKDIR /deps
633+
634+
# Install system dependencies
635+
RUN apt-get update && apt-get install -y dnsutils
636+
637+
# Install MaxText dependencies
638+
RUN pip install google-cloud-monitoring
639+
640+
# Install Python requirements
641+
COPY requirements.txt .
642+
RUN pip install -r requirements.txt
643+
644+
# Clone MaxText
645+
RUN git clone https://github.com/AI-Hypercomputer/maxtext.git /app \
646+
&& cd /app \
647+
&& git checkout maxtext-tutorial-v1.0.0
648+
649+
# Set working directory to MaxText root for the runner
650+
WORKDIR /app
651+
652+
# Copy the wrapper script
653+
COPY run_maxtext.sh /app/run_maxtext.sh
654+
RUN chmod +x /app/run_maxtext.sh
655+
656+
# The entrypoint is left at the base image's default; gcluster may override it
657+
```
658+
659+
#### `cluster-toolkit/maxtext_workload_v7x/run_maxtext.sh`
660+
661+
```bash
662+
#!/bin/bash
663+
664+
# Exit on error
665+
set -e
666+
667+
echo "Starting MaxText Workload..."
668+
669+
# 1. Set environment variables
670+
export ENABLE_PATHWAYS_PERSISTENCE='1'
671+
# Combine all the required XLA flags into LIBTPU_INIT_ARGS
672+
export LIBTPU_INIT_ARGS=" --xla_tpu_scoped_vmem_limit_kib=61440 --xla_tpu_bf16_emission_mode=NATIVE_EMISSION --xla_tpu_enable_sparse_core_collective_offload_all_reduce=true --xla_tpu_use_single_sparse_core_for_all_gather_offload=true "
673+
674+
export JAX_PLATFORMS="tpu,cpu"
675+
export ENABLE_PJRT_COMPATIBILITY=true
676+
export PYTHONPATH=$PYTHONPATH:$(pwd)/src
677+
export JAX_TRACEBACK_FILTERING=off
678+
679+
# 2. Extract arguments
680+
OUTPUT_DIR=${1}
681+
if [ -z "$OUTPUT_DIR" ]; then
682+
echo "Error: Output directory argument missing."
683+
echo "Usage: $0 <output_gcs_bucket>"
684+
exit 1
685+
fi
686+
687+
MODEL_NAME=${2:-"llama3.1-8b"}
688+
689+
echo "LIBTPU_INIT_ARGS=$LIBTPU_INIT_ARGS"
690+
echo "OUTPUT_DIR=$OUTPUT_DIR"
691+
echo "MODEL_NAME=$MODEL_NAME"
692+
693+
# 3. Run training
694+
python3 src/MaxText/train.py src/MaxText/configs/base.yml \
695+
model_name=$MODEL_NAME \
696+
skip_jax_distributed_system=False \
697+
dtype=bfloat16 \
698+
per_device_batch_size=1 \
699+
ici_fsdp_parallelism=64 \
700+
max_target_length=4096 \
701+
profiler=xplane \
702+
profile_periodically_period=10000 \
703+
async_checkpointing=False \
704+
enable_checkpointing=False \
705+
use_iota_embed=True \
706+
remat_policy=custom \
707+
decoder_layer_input=offload \
708+
query_proj=offload \
709+
key_proj=offload \
710+
value_proj=offload \
711+
out_proj=offload \
712+
dataset_type=synthetic \
713+
opt_type=adamw \
714+
mu_dtype=bfloat16 \
715+
tokenizer_type=tiktoken \
716+
tokenizer_path=assets/tokenizer_llama3.tiktoken \
717+
sa_use_fused_bwd_kernel=True \
718+
attention=flash \
719+
steps=30 \
720+
base_output_directory=$OUTPUT_DIR \
721+
use_vertex_tensorboard=false
722+
```
723+
724+
#### `cluster-toolkit/maxtext_workload_v7x/build.sh`
725+
726+
```bash
727+
#!/bin/bash
728+
729+
# Get current project
730+
PROJECT=$(gcloud config get-value project)
731+
732+
if [ -z "$PROJECT" ]; then
733+
echo "Error: Could not determine GCP project. Please run 'gcloud config set project <PROJECT_ID>'"
734+
exit 1
735+
fi
736+
737+
IMAGE_NAME=gcr.io/$PROJECT/maxtext-runner:latest
738+
739+
echo "Building image $IMAGE_NAME using Cloud Build..."
740+
gcloud builds submit --tag $IMAGE_NAME .
741+
742+
echo "Image built successfully!"
743+
echo "You can now submit the job with:"
744+
echo " gcluster job submit --image $IMAGE_NAME --command 'bash run_maxtext.sh <OUTPUT_DIR>'"
745+
```
746+
747+
#### `cluster-toolkit/maxtext_workload_v7x/submit.sh`
748+
749+
```bash
750+
#!/bin/bash
751+
752+
# Configuration - UPDATE THESE
753+
CLUSTER_NAME="tpu7xpkv"
754+
ZONE="us-central1-c"
755+
OUTPUT_DIR="gs://gke-aishared-gsc-dev/maxtext_output_7x"
756+
757+
# Look up project
758+
PROJECT=$(gcloud config get-value project)
759+
760+
if [ -z "$PROJECT" ]; then
761+
echo "Error: Could not determine GCP project. Please run 'gcloud config set project <PROJECT_ID>'"
762+
exit 1
763+
fi
764+
765+
IMAGE_NAME=gcr.io/$PROJECT/maxtext-runner:latest
766+
767+
echo "Ensuring permissions for tpu7xpkv-gke-wl-sa..."
768+
# Note: Ensure the SA matches the one created by the 7x blueprint (tpu7xpkv-gke-wl-sa)
769+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/logging.logWriter" --quiet
770+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.admin" --quiet
771+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/monitoring.metricWriter" --quiet
772+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/logging.viewer" --quiet
773+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-wl-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.objectViewer" --quiet
774+
775+
echo "Ensuring permissions for node pool service account..."
776+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-np-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/artifactregistry.reader" --quiet
777+
gcloud projects add-iam-policy-binding $PROJECT --member="serviceAccount:tpu7xpkv-gke-np-sa@${PROJECT}.iam.gserviceaccount.com" --role="roles/storage.objectViewer" --quiet
778+
779+
echo "Submitting MaxText job to cluster $CLUSTER_NAME..."
780+
781+
# Locate the gcluster binary, whether this script is run from the cluster-toolkit root or from maxtext_workload_v7x
782+
if [ -f "../gcluster" ]; then
783+
GCLUSTER="../gcluster"
784+
else
785+
GCLUSTER="./gcluster"
786+
fi
787+
788+
$GCLUSTER job submit \
789+
--name maxtext-llama3-1-final-tpu7x-32 \
790+
--cluster $CLUSTER_NAME \
791+
--cluster-region us-central1 \
792+
--image $IMAGE_NAME \
793+
--command "cd /app && sed -i 's/use_vertex_tensorboard=false/use_vertex_tensorboard=false run_name=llama3-1-7x-test1/g' run_maxtext.sh && bash run_maxtext.sh $OUTPUT_DIR" \
794+
--accelerator tpu7x-32 \
795+
--nodes 1 \
796+
--vms-per-slice 8 \
797+
--topology 2x4x4 \
798+
--priority medium \
799+
--service-account workload-identity-k8s-sa
800+
```
801+
802+
### 10.2 Build and Submit
803+
804+
```bash
805+
cd maxtext_workload_v7x
806+
./build.sh
807+
./submit.sh
808+
```
809+
810+
### 10.3 Verify Job and Logs
811+
812+
TPU v7x utilizes Megacore, which initializes 2 logical devices per chip. For a 32-chip slice (2x4x4 topology), you should see 64 logical devices in the logs.
813+
814+
**Using `gcluster`**:
815+
816+
```bash
817+
# List jobs
818+
./gcluster job list --project <YOUR_PROJECT_ID> --cluster tpu7xpkv --cluster-region us-central1
819+
820+
# View logs
821+
./gcluster job logs maxtext-llama3-1-final-tpu7x-32 --project <YOUR_PROJECT_ID> --cluster tpu7xpkv --cluster-region us-central1
822+
```
823+
824+
**Verification Highlights**:
825+
In the logs, look for successful JAX initialization and step processing:
826+
827+
```text
828+
System Information: Jax Version: 0.8.2
829+
System Information: Jax Backend: PJRT C API
830+
TFRT TPU7x
831+
Num_devices: 64, shape (1, 1, 64, 1, 1, 1, 1, 1, 1, 1, 1, 1)
832+
...
833+
completed step: 4, seconds: 1.182, TFLOP/s/device: 167.154, Tokens/s/device: 3464.384, loss: 10.187
834+
completed step: 5, seconds: 1.186, TFLOP/s/device: 166.597, Tokens/s/device: 3452.840, loss: 9.184
835+
```
836+
837+
## 11. Cleanup
582838
583839
To avoid incurring unnecessary costs, destroy the deployed GKE cluster and its resources:
584840

pkg/orchestrator/gke/gke_orchestrator.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,8 @@ func (g *GKEOrchestrator) ensureClusterQueueCoverage(localQueueName string) erro
334334
patch := `[
335335
{"op": "add", "path": "/spec/resourceGroups/0/coveredResources/-", "value": "cpu"},
336336
{"op": "add", "path": "/spec/resourceGroups/0/coveredResources/-", "value": "memory"},
337-
{"op": "add", "path": "/spec/resourceGroups/0/flavors/0/resources/-", "value": {"name": "cpu", "nominalQuota": "500"}},
338-
{"op": "add", "path": "/spec/resourceGroups/0/flavors/0/resources/-", "value": {"name": "memory", "nominalQuota": "2000Gi"}}
337+
{"op": "add", "path": "/spec/resourceGroups/0/flavors/0/resources/-", "value": {"name": "cpu", "nominalQuota": "2000"}},
338+
{"op": "add", "path": "/spec/resourceGroups/0/flavors/0/resources/-", "value": {"name": "memory", "nominalQuota": "20000Gi"}}
339339
]`
340340

341341
res := g.executor.ExecuteCommand("kubectl", "patch", "clusterqueue", cqName, "--type", "json", "-p", patch)
@@ -1050,6 +1050,9 @@ func (g *GKEOrchestrator) GenerateGKENodeSelectorLabel(acceleratorType string) s
10501050
if strings.HasPrefix(acceleratorType, "v6e-") {
10511051
return "tpu-v6e-slice"
10521052
}
1053+
if strings.Contains(acceleratorType, "tpu7x") {
1054+
return "tpu7x"
1055+
}
10531056
switch acceleratorType {
10541057
case "nvidia-tesla-a100":
10551058
return "nvidia-tesla-a100"
@@ -1123,6 +1126,7 @@ var defaultResourceLimits = map[string][4]string{
11231126
"tpu-v5-lite-podslice": {"1", "4Gi", "", "4"},
11241127
"tpu-v5-lite-device": {"1", "4Gi", "", "4"},
11251128
"tpu-v6e-slice": {"48", "240Gi", "", "4"},
1129+
"tpu7x": {"96", "800Gi", "", "4"},
11261130
"": {"0.5", "512Mi", "", ""},
11271131
}
11281132

@@ -1600,7 +1604,7 @@ func (g *GKEOrchestrator) buildNodeSelector(schedOpts scheduling.SchedulingOptio
16001604
if nodeSelector == nil {
16011605
nodeSelector = make(map[string]string)
16021606
}
1603-
if strings.Contains(accelLabel, "tpu-v6e") {
1607+
if strings.Contains(accelLabel, "tpu-v6e") || strings.Contains(accelLabel, "tpu7x") {
16041608
nodeSelector["cloud.google.com/gke-tpu-accelerator"] = accelLabel
16051609
} else {
16061610
nodeSelector["cloud.google.com/gke-accelerator"] = accelLabel

0 commit comments

Comments
 (0)