Skip to content

Commit 49cb18a

Browse files
authored
make e2e training benchmark support mx (#2776)
Update [ghstack-poisoned]
1 parent 9192799 commit 49cb18a

File tree

2 files changed

+15
-7
lines changed

2 files changed

+15
-7
lines changed

benchmarks/float8/training/llama3.sh

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,22 +17,30 @@ LOG_FILE="/tmp/float8_training_log.txt"
1717
# validate user has specified torchtitan root directory
1818
if [ -z "${TORCHTITAN_ROOT}" ]; then
1919
echo "Error: TORCHTITAN_ROOT environment variable is not set. Please set it before running this script."
20-
echo "Usage: TORCHTITAN_ROOT=<directory> ./float8_training_benchmark.sh"
20+
echo "Usage: TORCHTITAN_ROOT=<directory> ./llama3.sh"
2121
echo "Optional parameters configurable via environment variables:"
2222
echo " * FLOAT8_RECIPE_WITH_BEST_SETTINGS: 'rowwise' or 'tensorwise'. If set, use float8 training in torchtitan with the specified recipe, including the additional settings which are optimal for that recipe. Otherwise, use bf16 mixed precision training."
23+
echo " * MX_RECIPE: any valid MX recipe name. Note: only one of FLOAT8_RECIPE_WITH_BEST_SETTINGS and MX_RECIPE can be set."
2324
echo " * LOCAL_BATCH_SIZE: defaults to 1."
2425
echo " * STEPS: defaults to 100."
2526
echo " * EXTRA_ARGS: additional arguments to pass to the torchtitan training script."
2627
exit 1
2728
fi
2829

2930
# validate recipe name
30-
if [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ]; then
31+
if [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ] && [ -n "${MX_RECIPE}" ]; then
32+
echo "Error: both FLOAT8_RECIPE_WITH_BEST_SETTINGS and MX_RECIPE are set, please only set one of them." >&2
33+
exit 1
34+
elif [ -n "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" ]; then
3135
if [ "${FLOAT8_RECIPE_WITH_BEST_SETTINGS}" == "tensorwise" ]; then
3236
FLOAT8_ARGS="--model.converters="float8" --float8.enable_fsdp_float8_all_gather --float8.precompute_float8_dynamic_scale_for_fsdp"
3337
else
3438
FLOAT8_ARGS="--model.converters="float8" --float8.recipe_name=${FLOAT8_RECIPE_WITH_BEST_SETTINGS}"
3539
fi
40+
elif [ -n "${MX_RECIPE}" ]; then
41+
FLOAT8_ARGS="--model.converters="mx" --mx.recipe_name=${MX_RECIPE}"
42+
else
43+
FLOAT8_ARGS=""
3644
fi
3745

3846

@@ -51,7 +59,7 @@ CONFIG_FILE="./torchtitan/models/llama3/train_configs/llama3_8b.toml" ${TORCHTIT
5159
cd $original_dir
5260

5361
# parse logs to calculate top line metrics
54-
python parse_torchtitan_logs.py --log-file ${LOG_FILE}
62+
python benchmarks/float8/training/parse_torchtitan_logs.py --log-file ${LOG_FILE}
5563

5664
# clean up logs
5765
rm ${LOG_FILE}

torchao/float8/README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ To reproduce these benchmarks, you can follow these steps:
5353
1. On a machine with compatible GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
5454
including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
5555
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
56-
3. From the `torchao/benchmarks/float8/training/` directory, you can run the following commands to reproduce the benchmarks above:
57-
- bf16 + compile: `TORCHTITAN_ROOT=<path> ./torchtitan_benchmark.sh`
58-
- float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./torchtitan_benchmark.sh`
59-
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./torchtitan_benchmark.sh`
56+
3. From the root of the `torchao` repository (the directory that contains `benchmarks/`), you can run the following commands to reproduce the benchmarks above:
57+
- bf16 + compile: `TORCHTITAN_ROOT=<path> ./benchmarks/float8/training/llama3.sh`
58+
- float8 tensorwise with float8 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="tensorwise" ./benchmarks/float8/training/llama3.sh`
59+
- float8 rowwise with bf16 all-gather + compile: `TORCHTITAN_ROOT=<path> FLOAT8_RECIPE_WITH_BEST_SETTINGS="rowwise" ./benchmarks/float8/training/llama3.sh`
6060

6161
See the float8 training benchmarking [guide](../../benchmarks/float8/training/README.md) for more details.
6262

0 commit comments

Comments
 (0)