
Commit d8bb51f (1 parent: 49cb18a)

Update mx_formats README.md (#2777)

* Update mx_formats README.md
* Update README.md

torchao/prototype/mx_formats/README.md

Lines changed: 26 additions & 6 deletions
@@ -7,15 +7,35 @@ in native PyTorch. We are currently in prototype and are actively working on op

| workflow | emulation | performance | accuracy |
| --- | --- | --- | --- |
-| training with mxfp8 | ✅ | 🚧 [active development](https://github.com/pytorch/ao/issues/1768) | ✅ |
-| inference (weight-only) with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 |
-
-We plan to add the following features in the near future:
-* other inference workflows such as dynamic quantization
-* a unified training to inference workflow
+| training with mxfp8 | ✅ | ✅ | ✅ |
+| inference with mxfp8, mxfp6, mxfp4 | ✅ | 🔲 | 🔲 |

ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) and the [performance tracker](https://github.com/pytorch/ao/issues/1768) for upcoming features.</em>

+## Training e2e benchmarks on NVIDIA B200
+
+- Single-node training on 8xB200 GPUs limited to 750W, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
+- pytorch version: `2.9.0.dev20250815+cu128`, torchao version: `0.13.0+gite4e681be6`, torchtitan commit: `6fc499f6f5b32151a799188be2208cfb09faed30`
+
+| Model     | Scaling                           | Peak Memory (GB) | Median tokens/second | Speedup over baseline |
+| --------- | --------------------------------- | ---------------- | -------------------- | --------------------- |
+| Llama3-8b | none (bfloat16)                   | 33.71            | 8307.5               | -                     |
+| Llama3-8b | float8 tensorwise (f8 all-gather) | 33.38            | 10417.0              | 25.4%                 |
+| Llama3-8b | mxfp8_cublas                      | 33.88            | 9969.0               | 20.0%                 |
+| Llama3-8b | mxfp8_cublas_rceil                | 33.88            | 9642.0               | 16.1%                 |
+| Llama3-8b | float8 rowwise                    | 33.72            | 8640.5               | 4.0%                  |
+
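The "Speedup over baseline" column is the percentage gain in median tokens/second over the bf16 row; the formula below is inferred from the table values rather than stated in the commit:

```python
# Speedup over baseline, as percent gain in median tokens/second vs bf16.
baseline_tps = 8307.5  # Llama3-8b, bf16
recipes = {
    "float8 tensorwise": 10417.0,
    "mxfp8_cublas": 9969.0,
    "mxfp8_cublas_rceil": 9642.0,
    "float8 rowwise": 8640.5,
}
for name, tps in recipes.items():
    print(f"{name}: {(tps / baseline_tps - 1) * 100:.1f}%")
# float8 tensorwise: 25.4%, mxfp8_cublas: 20.0%, ... matching the table above
```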
+**Reproducing training benchmarks**
+
+To reproduce these benchmarks, you can follow these steps:
+
+1. On a machine with compatible GPUs, clone torchtitan and follow the local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
+2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
+3. From the `torchao/` directory, run the following commands to reproduce the benchmarks above:
+   - bf16 + compile: `TORCHTITAN_ROOT=<path> ./benchmarks/float8/training/llama3.sh`
+   - mxfp8_cublas: `TORCHTITAN_ROOT=<path> MX_RECIPE="mxfp8_cublas" ./benchmarks/float8/training/llama3.sh`
+   - mxfp8_cublas_rceil: `TORCHTITAN_ROOT=<path> MX_RECIPE="mxfp8_cublas_rceil" ./benchmarks/float8/training/llama3.sh`
+
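Since the table above was measured on Blackwell hardware, it can help to confirm the local GPU before launching the scripts. A minimal sketch; the compute-capability value for B200-class GPUs (10.0) is an assumption here, not stated in the commit:

```python
import torch

# Print the local GPU and its compute capability before running the benchmark scripts.
name = torch.cuda.get_device_name()
major, minor = torch.cuda.get_device_capability()
print(f"{name}: sm_{major}{minor}")

# Assumption: B200-class (Blackwell) GPUs report compute capability 10.0.
if (major, minor) < (10, 0):
    print("note: the numbers above were measured on 8xB200; expect different results")
```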
# User API

## MX training
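The hunk's trailing context ends at the README's "MX training" user API section. For orientation, a minimal sketch of what that flow looks like, assuming the prototype `MXLinearConfig` and `quantize_` API in `torchao.prototype.mx_formats` (the exact signature lives in the full README, not in this diff):

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.mx_formats import MXLinearConfig  # assumed prototype import path

# Build a toy model and swap its Linear layers to MX (mxfp8) training.
m = torch.nn.Sequential(torch.nn.Linear(32, 128), torch.nn.Linear(128, 32)).cuda()

# elem_dtype/block_size are assumed parameter names; block_size=32 is the MX spec default.
config = MXLinearConfig(elem_dtype=torch.float8_e4m3fn, block_size=32)
quantize_(m, config)

# Training then proceeds as usual; the benchmarks above also apply torch.compile.
m = torch.compile(m)
```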
