Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
for model_size in models:
run_cmds = [
"pip show aqtp",
f"bash MaxText/configs/{tpu}/{model_size}.sh EXECUTABLE=train.py OUTPUT_PATH={base_output_directory} PLATFORM=gke",
f"bash maxtext/configs/tpu/{tpu}/{model_size}.sh EXECUTABLE=train.py OUTPUT_PATH={base_output_directory} PLATFORM=gke",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please use src/maxtext/configs as the path instead of maxtext/configs. The maxtext/configs path will work temporarily, but we will likely remove it soon.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for the clarification. We will update the path to src/maxtext/configs for now.

]

tests.extend(
Expand Down
2 changes: 1 addition & 1 deletion dags/examples/maxtext_profile_namegen_example_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@

base_command = (
f"export BASE_OUTPUT_PATH={BASE_OUTPUT_PATH} && "
+ "python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=${RUN_NAME} model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 dataset_path=gs://maxtext-dataset per_device_batch_size=4 enable_checkpointing=false ici_fsdp_parallelism=-1 max_target_length=1024 async_checkpointing=false attention=flash dtype=bfloat16 weight_dtype=bfloat16"
+ "python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=gs://runner-maxtext-logs run_name=${RUN_NAME} model_name=mixtral-8x7b tokenizer_path=assets/tokenizer.mistral-v1 dataset_path=gs://maxtext-dataset per_device_batch_size=4 enable_checkpointing=false ici_fsdp_parallelism=-1 max_target_length=1024 async_checkpointing=false attention=flash dtype=bfloat16 weight_dtype=bfloat16"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A separate PR would be good, but can you please also update the python3 -m MaxText.train references to use python3 -m maxtext.trainers.pre_train.train instead? The old ones are deprecated and will be removed in the near future. This is the full list of commands we will need to update: https://github.com/AI-Hypercomputer/maxtext/tree/102af23138003f20df9e9c8194ee6617e47881f9/src/MaxText

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Understood. We will open a separate PR to migrate those deprecated commands to the new training path.

)

test_models_tpu = {
Expand Down
4 changes: 2 additions & 2 deletions dags/examples/maxtext_profile_sweep_example_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def dict_to_arg(param_dict):
"cluster": XpkClusters.TPU_V6E_256_MLPERF_CLUSTER,
"train_command": [
f"export BASE_OUTPUT_PATH={BASE_OUTPUT_PATH} && "
"python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} model_name=mixtral-8x7b "
"python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} model_name=mixtral-8x7b "
# add profiler config: ensure steps > skip_first_n_steps_for_profiler + profiler_steps
"steps=10 profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=3 "
+ dict_to_arg({
Expand Down Expand Up @@ -81,7 +81,7 @@ def dict_to_arg(param_dict):
"base_output_directory": "gs://runner-maxtext-logs",
"train_command": [
f"export BASE_OUTPUT_PATH={BASE_OUTPUT_PATH} && "
"python3 -m MaxText.train MaxText/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} model_name=mixtral-8x7b "
"python3 -m MaxText.train maxtext/configs/base.yml base_output_directory=${BASE_OUTPUT_PATH} model_name=mixtral-8x7b "
# add profiler config: ensure steps > skip_first_n_steps_for_profiler + profiler_steps
"steps=10 profiler=xplane skip_first_n_steps_for_profiler=5 profiler_steps=3 "
+ dict_to_arg({
Expand Down
2 changes: 1 addition & 1 deletion dags/examples/maxtext_sweep_gke_example_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
# MaxText set up and run commands
base_output_directory = "gs://maxtext-experiments-multipod"
base_run_model_cmds = [
f"python3 -m MaxText.train MaxText/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10",
f"python3 -m MaxText.train maxtext/configs/base.yml base_output_directory={base_output_directory} dataset_path=gs://max-datasets-rogue enable_checkpointing=false global_parameter_scale=16 steps=10",
]

# Get list of MaxText GKE XPK jobs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def get_config(
f"export KV_QUANT_AXIS={model_configs['kv_quant_axis']}",
# Start JetStream MaxText server in the background
"""python3 -m MaxText.maxengine_server \
MaxText/configs/inference_jetstream.yml \
maxtext/configs/inference/inference_jetstream.yml \
model_name=${MODEL_NAME} \
tokenizer_path=${TOKENIZER_PATH} \
weight_dtype=${WEIGHT_DTYPE} \
Expand Down
2 changes: 1 addition & 1 deletion dags/inference/configs/maxtext_gpu_inference_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_maxtext_gpu_inference_config(
"export XLA_PYTHON_CLIENT_MEM_FRACTION=0.94",
"export PER_DEVICE_BATCH_SIZE=190",
"export PYTHONPATH='/opt/maxtext:${PYTHONPATH:+:$PYTHONPATH}'",
f"python3 -m MaxText.inference_microbenchmark MaxText/configs/base.yml base_output_directory=$BASE_OUTPUT_DIRECTORY model_name='llama2-70b' max_prefill_predict_length=1024 max_target_length=2048 attention=dot_product scan_layers=false hardware=gpu async_checkpointing=$ASYNC_CHECKPOINTING per_device_batch_size=$PER_DEVICE_BATCH_SIZE inference_microbenchmark_prefill_lengths=1024 inference_microbenchmark_stages=prefill,generate inference_microbenchmark_loop_iters=64 run_name=$(date +%Y-%m-%d-%H-%M) ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 weight_dtype=bfloat16 kv_quant_dtype=fp8 quantize_kvcache=True quantization=aqt_fp8 > output.txt",
f"python3 -m MaxText.inference_microbenchmark maxtext/configs/base.yml base_output_directory=$BASE_OUTPUT_DIRECTORY model_name='llama2-70b' max_prefill_predict_length=1024 max_target_length=2048 attention=dot_product scan_layers=false hardware=gpu async_checkpointing=$ASYNC_CHECKPOINTING per_device_batch_size=$PER_DEVICE_BATCH_SIZE inference_microbenchmark_prefill_lengths=1024 inference_microbenchmark_stages=prefill,generate inference_microbenchmark_loop_iters=64 run_name=$(date +%Y-%m-%d-%H-%M) ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_tensor_parallelism=-1 weight_dtype=bfloat16 kv_quant_dtype=fp8 quantize_kvcache=True quantization=aqt_fp8 > output.txt",
"wget https://raw.githubusercontent.com/GoogleCloudPlatform/ml-auto-solutions/refs/heads/master/dags/inference/utils/maxtext_gpu_microbenchmark_jsonl_converter.py",
f"python maxtext_gpu_microbenchmark_jsonl_converter.py {jsonl_output_path}",
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def config(
# Configure flags
"export XLA_FLAGS='--xla_disable_hlo_passes=rematerialization'",
f"""python3 -m MaxText.inference_microbenchmark_sweep \
MaxText/configs/base.yml \
maxtext/configs/base.yml \
model_name={model_configs['model_name']} \
tokenizer_path=assets/{model_configs['tokenizer']} \
weight_dtype={model_configs['weight_dtype']} \
Expand Down
4 changes: 2 additions & 2 deletions dags/multipod/configs/gke_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def get_gke_maxtext_nightly_config(
(
"JAX_PLATFORM_NAME=TPU XLA_FLAGS='--xla_dump_to=/tmp/xla_dump/'"
" ENABLE_PJRT_COMPATIBILITY=true"
f" python3 -m MaxText.train MaxText/configs/base.yml run_name={run_name}"
f" python3 -m MaxText.train maxtext/configs/base.yml run_name={run_name}"
f" base_output_directory={base_output_directory}"
" dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
" model_name=llama3-8b per_device_batch_size=12 reuse_example_batch=1 metrics_file='metrics.txt'"
Expand Down Expand Up @@ -213,7 +213,7 @@ def get_gke_gpt3_6b_nightly_config(
(
"JAX_PLATFORM_NAME=TPU XLA_FLAGS='--xla_dump_to=/tmp/xla_dump/'"
" ENABLE_PJRT_COMPATIBILITY=true"
f" python3 -m MaxText.train MaxText/configs/base.yml run_name={run_name} model_name=gpt3-6b"
f" python3 -m MaxText.train maxtext/configs/base.yml run_name={run_name} model_name=gpt3-6b"
f" base_output_directory={base_output_directory}"
" dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
" per_device_batch_size=12 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/configs/maxtext_gce_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def get_maxtext_nightly_config(
"cd /tmp/maxtext &&"
" JAX_PLATFORM_NAME=TPU XLA_FLAGS='--xla_dump_to=/tmp/xla_dump/'"
" ENABLE_PJRT_COMPATIBILITY=true"
f" python3 -m MaxText.train MaxText/configs/base.yml run_name={run_name}"
f" python3 -m MaxText.train maxtext/configs/base.yml run_name={run_name}"
f" base_output_directory={base_output_directory}"
" dataset_path=gs://max-datasets-rogue dataset_type=synthetic"
" per_device_batch_size=12 reuse_example_batch=1 global_parameter_scale=1 metrics_file='metrics.txt'"
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/maxtext_configs_aot.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@
output_path = "dummy-output-dir"

cmd = (
f"bash src/MaxText/configs/{tpu}/{model_size}.sh "
f"bash src/maxtext/configs/tpu/{tpu}/{model_size}.sh "
"EXECUTABLE=train_compile "
f"M_COMPILE_TOPOLOGY={tpu}-{num_cores} "
f"M_COMPILE_TOPOLOGY_NUM_SLICES={n} "
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/maxtext_configs_aot_gpu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
)

# GPU AoT tests
cmd = "bash src/MaxText/configs/a3/llama_2_7b/8vm.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=a3 M_COMPILE_TOPOLOGY_NUM_SLICES=8"
cmd = "bash src/maxtext/configs/gpu/a3/llama_2_7b/8vm.sh EXECUTABLE=train_compile M_COMPILE_TOPOLOGY=a3 M_COMPILE_TOPOLOGY_NUM_SLICES=8"
stable_a3_gpu = gke_config.get_maxtext_end_to_end_gpu_gke_test_config(
time_out_in_min=300,
test_name="maxtext-aot-a3-stable",
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/maxtext_configs_aot_hybridsim.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def hybridsim_compile_and_run(group_id):
' --xla_dump_large_constants"'
),
(
f"bash src/MaxText/configs/v{tpu_version_str}/{model_size}.sh"
f"bash src/maxtext/configs/tpu/v{tpu_version_str}/{model_size}.sh"
" EXECUTABLE=train_compile"
f" M_COMPILE_TOPOLOGY=v{tpu_version_str}-{num_cores}"
f" M_COMPILE_TOPOLOGY_NUM_SLICES={n}"
Expand Down
4 changes: 2 additions & 2 deletions dags/multipod/maxtext_gpu_end_to_end.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@ def run_maxtext_tests(dag: models.DAG):
timestamp = datetime.datetime.now().strftime("%Y-%m-%d-%H-%M")
train_base = (
"XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
"python3 -m MaxText.train MaxText/configs/base.yml "
"python3 -m MaxText.train maxtext/configs/base.yml "
"base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
"steps=2 enable_checkpointing=false attention=dot_product"
)
decode_base = (
"XLA_PYTHON_CLIENT_MEM_FRACTION=0.65 TF_FORCE_GPU_ALLOW_GROWTH=true "
"python3 -m MaxText.decode MaxText/configs/base.yml "
"python3 -m MaxText.decode maxtext/configs/base.yml "
"base_output_directory=gs://runner-maxtext-logs dataset_path=gs://maxtext-dataset "
"steps=2 enable_checkpointing=false attention=dot_product "
"max_target_length=128 per_device_batch_size=1"
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/maxtext_profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
for mode, image in docker_images:
profiling_cmds = (
f"export RUN_NAME=profiling_{mode.value}_$(date +%Y-%m-%d-%H-%M-%S)",
"python3 -m MaxText.train MaxText/configs/base.yml"
"python3 -m MaxText.train maxtext/configs/base.yml"
f" run_name=$RUN_NAME base_output_directory={base_output_directory}"
f" dataset_path={dataset_path} profiler=xplane steps=20",
f"gcloud storage cp --recursive {base_output_directory}/$RUN_NAME/tensorboard .",
Expand Down
2 changes: 1 addition & 1 deletion dags/multipod/maxtext_profiling_vertex_ai_tensorboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
f"-{accelerator}-{current_datetime}"
),
(
"python3 -m MaxText.train MaxText/configs/base.yml"
"python3 -m MaxText.train maxtext/configs/base.yml"
f" run_name=$RUN_NAME base_output_directory"
f"={base_output_directory} dataset_path={dataset_path}"
f" profiler=xplane steps=10"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@
time_out_in_min=60,
run_model_cmds=(
f"JAX_PLATFORMS=tpu,cpu ENABLE_PJRT_COMPATIBILITY=true TPU_SLICE_BUILDER_DUMP_CHIP_FORCE=true TPU_SLICE_BUILDER_DUMP_ICI=true JAX_FORCE_TPU_INIT=true ENABLE_TPUNETD_CLIENT=true && "
f"python -m MaxText.train MaxText/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} "
f"python -m MaxText.train maxtext/configs/base.yml run_name={slice_num}slice-V{cluster.device_version}_{cores}-maxtext-jax-stable-stack-{current_datetime} "
"steps=30 per_device_batch_size=1 max_target_length=4096 model_name=llama2-7b "
"enable_checkpointing=false attention=dot_product remat_policy=minimal_flash use_iota_embed=true scan_layers=false "
"dataset_type=synthetic async_checkpointing=false "
Expand Down
Loading