
Commit 853f87d

torchao.float8: update with AMD MI300X benchmark results (#2736)
I got a devgpu with 8 AMD MI300X GPUs, ran the torchtitan benchmarks (without any performance debugging), and added the numbers I saw to the readme. The tensorwise number looks lower than expected; we can debug/fix this in a future PR.
1 parent 948ade1 commit 853f87d

File tree

1 file changed (+22, -15 lines)


torchao/float8/README.md

Lines changed: 22 additions & 15 deletions
```diff
@@ -14,6 +14,7 @@ and up to [**1.25x at 8 GPU / 8B parameter count scale**](#training-benchmarks),
 * seamless composability with [DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html), including [FSDP2 with float8 weight all-gather](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359) and [Async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487)
 * seamless composability with [PyTorch Activation Checkpointing](https://pytorch.org/blog/activation-checkpointing-techniques/)
 * three different scaling recipes to trade off performance vs accuracy: tensorwise (fastest), rowwise, rowwise_with_gw_hp (most accurate)
+* supports both NVIDIA and AMD hardware

 ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.</em>

@@ -186,22 +187,28 @@ python test/float8/test_fsdp2/test_fsdp2.py
 [Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance, for both rowwise
 and tensorwise scaling. The training benchmarks were all run using:

-- Single-node training on 8xH100 GPUs
-- Batch size 1
-- Sequence length 8192
-- Steps 100
-- `torch.compile`
-- FSDP2
-- pytorch version: `2.7.0a0+gitb98af95`
-- torchao version: `0.10.0+git890e0ac8`
-- torchtitan version: `0.0.2`
+#### NVIDIA H100

+- Single-node training on 8xH100 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
+- pytorch version: `2.7.0a0+gitb98af95`, torchao version: `0.10.0+git890e0ac8`, torchtitan version: `0.0.2`

-| Model | Scaling | Activation checkpointing | Peak Memory (GB) | Median tokens/second | Speedup over baseline
-| ------------- | ---------------------------------- | ------------------------ | ------------------| -------------------- | ---------------------
-| Llama3-8b | none (bfloat16) | per op SAC | 47.65 | 6150 | -
-| Llama3-8b | tensorwise with float8 all-gather | per op SAC | 47.77 | 7689.5 | 25.03%
-| Llama3-8b | rowwise with bfloat16 all-gather | per op SAC | 47.79 | 6768 | 10.05%
+| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline
+| ------------- | ---------------------------------- | ------------------| -------------------- | ---------------------
+| Llama3-8b | none (bfloat16) | 47.65 | 6150 | -
+| Llama3-8b | tensorwise with float8 all-gather | 47.77 | 7689.5 | 25.03%
+| Llama3-8b | rowwise with bfloat16 all-gather | 47.79 | 6768 | 10.05%
+
+#### AMD MI300x
+
+- Single-node training on 8xMI300X GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
+- pytorch version: `2.9.0.dev20250811+rocm6.4`, torchao version `0.13.0+git4fc4068d6`, torchtitan commit `2c8b5947991239913d67e2f7d22a255c3e2a9694`
+
+| Model | Scaling | Peak Memory (GB) | Median tokens/second | Speedup over baseline
+| ------------- | ---------------------------------- | ------------------| -------------------- | ---------------------
+| Llama3-8b | none (bfloat16) | 39.09 | 5376.5 | -
+| Llama3-8b | tensorwise with float8 all-gather | 39.07 | 6166.0 | 14.68%
+| Llama3-8b | rowwise_with_gw_hp with bfloat16 all-gather | 39.32 | 6100.0 | 13.46%
+| Llama3-8b | rowwise with bfloat16 all-gather | 39.32 | 5891.0 | 9.57%

 **Important notes**:
 - E2E speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
@@ -210,7 +217,7 @@ and tensorwise scaling. The training benchmarks were all run using:
 **Reproducing training benchmarks**
 To reproduce these benchmarks, you can follow these steps:

-1. On a machine with 8 H100 GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
+1. On a machine with compatible GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation),
 including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
 2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
 3. From the `torchao/benchmarks/float8/training/` directory, you can run the following commands to reproduce the benchmarks above:
```
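For context on the scaling recipes named in the README excerpt above (`tensorwise`, `rowwise`, `rowwise_with_gw_hp`), here is a minimal sketch of how a recipe is selected through torchao's float8 conversion API. It is not part of this commit; it assumes float8-capable hardware (e.g. H100 or MI300X) and a recent torchao release where `Float8LinearConfig.from_recipe_name` and `convert_to_float8_training` are available.

```python
# Minimal sketch (not part of this commit): convert a model's nn.Linear layers
# to float8 training with one of the recipes listed in the README, then compile.
# Requires float8-capable hardware (H100 / MI300X) and a recent torchao release.
import torch
import torch.nn as nn

from torchao.float8 import Float8LinearConfig, convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096, bias=False),
    nn.ReLU(),
    nn.Linear(4096, 4096, bias=False),
).to(device="cuda", dtype=torch.bfloat16)

# Recipe names match the README bullet: "tensorwise" (fastest), "rowwise",
# or "rowwise_with_gw_hp" (most accurate).
config = Float8LinearConfig.from_recipe_name("rowwise")
convert_to_float8_training(model, config=config)

# The benchmark rows above were all collected under torch.compile.
model = torch.compile(model)
out = model(torch.randn(8192, 4096, device="cuda", dtype=torch.bfloat16))
```

The "float8 all-gather" variants additionally rely on FSDP2 sharding the weights in float8; the exact benchmark commands referenced in step 3 fall outside the range shown in this hunk.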
