* e2e pretraining speedups of up to [**1.5x at 512 GPU / 405B parameter count scale**](https://pytorch.org/blog/training-using-float8-fsdp2/), and up to [**1.25x at 8 GPU / 8B parameter count scale**](#training-benchmarks), with performance and accuracy validated on up to [**2k GPUs**](https://pytorch.org/blog/accelerating-large-scale-training-and-convergence-with-pytorch-float8-rowwise-on-crusoe-2k-h200s/), via [torchtitan's float8 integration](https://github.com/pytorch/torchtitan/blob/main/docs/float8.md)
* seamless composability with [torch.compile](https://docs.pytorch.org/docs/stable/torch.compiler.html), [DTensor](https://docs.pytorch.org/docs/stable/distributed.tensor.html), [FSDP2 with float8 weight all-gather](https://dev-discuss.pytorch.org/t/enabling-float8-all-gather-in-fsdp2/2359), [Async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487), and [PyTorch AC](https://pytorch.org/blog/activation-checkpointing-techniques/)
* three recipes to trade off performance vs accuracy: `tensorwise` (fastest), `rowwise`, `rowwise_with_gw_hp` (most accurate); see the sketch after this list
* supports both NVIDIA and AMD hardware
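
For example, here is a minimal sketch of selecting one of these recipes, assuming the `convert_to_float8_training` and `Float8LinearConfig` entry points from `torchao.float8` in a recent torchao version (see the rest of this document for the authoritative API):

```python
import torch
import torch.nn as nn
from torchao.float8 import Float8LinearConfig, convert_to_float8_training

# toy model; float8 GEMMs expect linear dims that are multiples of 16
m = nn.Sequential(nn.Linear(4096, 8192), nn.SiLU(), nn.Linear(8192, 4096)).to(torch.bfloat16).to("cuda")

# pick one of the recipes above: "tensorwise", "rowwise", or "rowwise_with_gw_hp"
config = Float8LinearConfig.from_recipe_name("rowwise")

# swap eligible nn.Linear modules for float8 training linears, in place
convert_to_float8_training(m, config=config)

# compile for best performance, then run your usual training loop
m = torch.compile(m)
```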
ℹ️ <em>See the [feature tracker](https://github.com/pytorch/ao/issues/556) for upcoming features.</em>
ℹ️ <em>These APIs are training-only and float8-only, and we plan to [unify them with the rest of torchao](https://github.com/pytorch/ao/issues/894) in the future.</em>
# e2e training benchmarks
[Torchtitan](https://github.com/pytorch/torchtitan) was used to benchmark float8 training performance.
#### NVIDIA H100
- Single-node training on 8xH100 GPUs, batch size 1, sequence length 8192, steps 100, `torch.compile`, FSDP2, per-op SAC
- E2E speedups increase as M,K,N (GEMM dimensions) increase. Speedups as high as 1.5x have been measured with larger shapes ([example](https://pytorch.org/blog/training-using-float8-fsdp2/)).
- Rowwise scaling is better at handling outliers than tensorwise scaling, so these recipes are different points on the accuracy vs performance curve (a standalone illustration follows this list).
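
To see why, here is a small self-contained sketch (illustrative only, not the torchao implementation) comparing one tensorwise scale against per-row scales on a tensor with a single, deliberately exaggerated outlier:

```python
import torch

torch.manual_seed(0)
F8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0

x = torch.randn(64, 512)
x[0, 0] = 50_000.0  # one extreme outlier in row 0, exaggerated for illustration

def roundtrip(x, scale):
    # scale into the float8 range, cast, then rescale back (simulated quant/dequant)
    x_f8 = (x * scale).clamp(-F8_MAX, F8_MAX).to(torch.float8_e4m3fn)
    return x_f8.to(torch.float32) / scale

def sqnr(ref, approx):
    # signal-to-quantization-noise ratio in dB; higher is better
    return 10 * torch.log10(ref.pow(2).mean() / (ref - approx).pow(2).mean())

# tensorwise: a single scale, set by the global amax, i.e. by the outlier
tw = roundtrip(x, F8_MAX / x.abs().max())

# rowwise: one scale per row, so only row 0 "pays" for the outlier
rw = roundtrip(x, F8_MAX / x.abs().amax(dim=1, keepdim=True))

# compare fidelity on the rows that do NOT contain the outlier
print("tensorwise SQNR (dB):", sqnr(x[1:], tw[1:]).item())
print("rowwise    SQNR (dB):", sqnr(x[1:], rw[1:]).item())
```

With the tensorwise scale dragged down by the outlier, the remaining rows are pushed toward float8's subnormal range and lose precision; per-row scales keep each row near the full representable range, at the cost of computing more scales.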
**Reproducing training benchmarks**
To reproduce these benchmarks, you can follow these steps:
1. On a machine with compatible GPUs, clone torchtitan and follow local installation [steps](https://github.com/pytorch/torchtitan?tab=readme-ov-file#installation), including [downloading a tokenizer](https://github.com/pytorch/torchtitan?tab=readme-ov-file#downloading-a-tokenizer).
2. Install torchao following these [steps](https://github.com/pytorch/ao/tree/main?tab=readme-ov-file#installation).
3. From the `torchao/benchmarks/float8/training/` directory, run the provided benchmarking commands to reproduce the benchmarks above.
See the float8 training benchmarking [guide](torchao/benchmarks/float8/training/README.md) for more details.
# E2E training + inference flow
The first step in the E2E flow is to train your model and save a checkpoint. The second step is to load the checkpoint and optionally apply inference quantization before serving the model.
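
A minimal single-GPU sketch of that flow is below, reusing the `convert_to_float8_training` API shown above and torchao's `quantize_` inference API; the inference config class used here (`Float8DynamicActivationFloat8WeightConfig`) is an assumption and may be named differently depending on your torchao version:

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

def make_model():
    return nn.Sequential(nn.Linear(4096, 8192), nn.SiLU(), nn.Linear(8192, 4096)).to(torch.bfloat16).to("cuda")

# step 1: train with float8 (tensorwise recipe by default) and save a checkpoint
model = make_model()
convert_to_float8_training(model)
# ... run your training loop here (torch.compile recommended for real runs) ...
torch.save(model.state_dict(), "checkpoint.pt")

# step 2: load the checkpoint into a fresh model and optionally quantize it for serving
inference_model = make_model()
inference_model.load_state_dict(torch.load("checkpoint.pt"))
quantize_(inference_model, Float8DynamicActivationFloat8WeightConfig())  # config name may vary by torchao version
# ... serve inference_model ...
```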