- [Saving a quantized checkpoint](#saving-a-quantized-checkpoint)
  - [Add the scales to `Linear` layers](#add-the-scales-to-linear-layers)
  - [Update model config](#update-model-config)

# Why?

tl;dr:

Starting with the NVIDIA H100 GPU, GPUs have *hardware support* for 8 bit floating point operations.

1. Model takes less GPU RAM => more space for kv cache. Modern inference libraries (like vLLM/SGLang) will have higher/more stable performance with more space for the kv cache
2. Model parameters are half as big => half as much GPU memory bandwidth spent loading them
3. Depending on the GPU, fp8 FLOPS are just higher than `bf16` FLOPS. E.g., see the [H100 specifications](https://www.nvidia.com/en-us/data-center/h100/): bfloat16 has ~2k teraFLOPS and fp8 has ~4k teraFLOPS

# How?

## Note on executing fp8 models

When we talk about `fp8` models, we are typically only talking about the **weights being `fp8`**. The actual execution of the model is still done in `bf16`. So all the **intermediate tensors are still in `bf16`**, and it's the underlying CUDA kernels that take in `bf16` tensors and `fp8` weights.
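
To make that concrete, here is a rough sketch of the dtype flow (my own illustration; a real inference engine dispatches a fused kernel rather than this explicit upcast):

```python
import torch

# The checkpoint stores the weight in fp8; activations and outputs stay in bf16.
weight_fp8 = torch.randn(1024, 1024).to(torch.float8_e4m3fn)
x_bf16 = torch.randn(8, 1024, dtype=torch.bfloat16)

# A fused kernel reads the fp8 weight directly; conceptually it upcasts the
# weight on the fly (and folds in the weight scale, covered below), then runs
# a normal bf16 matmul.
y_bf16 = x_bf16 @ weight_fp8.to(torch.bfloat16).t()
assert y_bf16.dtype == torch.bfloat16
```
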
**fp8 models still use `bf16` kv cache by default** (since the kv cache stores kv values, which are intermediate tensors).
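
The kv cache dtype is a separate knob. For example, with vLLM's offline `LLM` API (the model name here is a placeholder), you only get an `fp8` kv cache if you explicitly opt in:

```python
from vllm import LLM

# kv_cache_dtype defaults to "auto", which keeps the cache in the model's
# activation dtype (bf16); pass "fp8" to also quantize the kv cache.
llm = LLM(model="your-org/your-model-FP8", kv_cache_dtype="fp8")
```
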
## fp8 bit format

There are a number of different `fp8` formats; the most common is `float8_e4m3fn`. Here are some facts about it:

1. This format has `1` sign bit, `4` bits for exponent (`e4`), and `3` bits for mantissa (`m3`)
2. Values can be between `[-448, +448]`

And here is how all the representable values are distributed (notice how there are many more representable values close to zero).

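These numbers are easy to sanity-check directly in PyTorch with `torch.finfo`:

```python
import torch

finfo = torch.finfo(torch.float8_e4m3fn)
print(finfo.bits)  # 8
print(finfo.max)   # 448.0
print(finfo.min)   # -448.0
print(finfo.eps)   # 0.125     (3 mantissa bits => spacing of 2**-3 around 1.0)
print(finfo.tiny)  # 0.015625  (smallest positive normal value, 2**-6)
```
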
So this leaves us with two questions for quantization:

1. `bf16` can store values between `[-3.38953e+38, +3.38953e+38]`; how do we fit that into the `fp8` range of `[-448, +448]`?
2. How do we take advantage of the distribution of values in `fp8`?

## Quantization - scaling to lower precision loss & handle large values

Above I showed the scale being a single value, but you can also have it be a tensor. If you look at some popular open source `fp8` models they typically use this option.

Why would you do this? To theoretically preserve accuracy, though if the values in your tensor are all relatively close together you won't get much benefit.

```python
# For an [N, K] weight quantized in [n, k] blocks, there is one scale per block:
assert scale.shape == torch.Size([N // n, K // k])
```
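
For concreteness, here is a rough sketch of how such a blockwise scale can be computed (my own illustration; the `128x128` block size and the absmax / 448 recipe are assumptions about a typical setup, not necessarily what any particular checkpoint uses):

```python
import torch

def quantize_blockwise(weight: torch.Tensor, n: int = 128, k: int = 128):
    """Quantize an [N, K] weight to fp8 with one scale per [n, k] block."""
    N, K = weight.shape
    assert N % n == 0 and K % k == 0
    # View the weight as an (N//n, K//k) grid of (n, k) blocks
    blocks = weight.float().reshape(N // n, n, K // k, k).permute(0, 2, 1, 3)
    # One scale per block: the block's absmax divided by fp8's max value (448)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    scale = (blocks.abs().amax(dim=(-1, -2)) / fp8_max).clamp(min=1e-12)
    # Scale each block into [-448, 448] and cast to fp8
    q_blocks = (blocks / scale[..., None, None]).to(torch.float8_e4m3fn)
    # Restore the original [N, K] layout
    quantized = q_blocks.permute(0, 2, 1, 3).reshape(N, K)
    assert scale.shape == torch.Size([N // n, K // k])
    return quantized, scale

weight_fp8, weight_scale = quantize_blockwise(torch.randn(1024, 2048, dtype=torch.bfloat16))
```

Dequantizing is the reverse: multiply each block by its scale on the way back to `bf16`.
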

# Saving a quantized checkpoint

For compatibility with things like vLLM there are a couple of things we need to do:

## Add the scales to `Linear` layers

We need to add the previously computed `weight_scale` as a parameter to each of the `Linear` layers. This basically means replacing each `Linear` layer with this custom `PackedLinear` class, where `weight` is the `fp8` tensor and `weight_scale` is the scale from the previous sections.

```python
class PackedLinear(torch.nn.Module):
    def __init__(self, weight: torch.Tensor, weight_scale: torch.Tensor):
        super().__init__()
        # fp8 (float8_e4m3fn) quantized weight
        self.weight = torch.nn.Parameter(weight, requires_grad=False)
        # scale(s) needed to dequantize `weight` back to bf16
        self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False)
```
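
And a sketch of doing the swap across a model (a hypothetical helper reusing the `quantize_blockwise` sketch from earlier, not the original code; bias handling is omitted for brevity):

```python
def pack_linears(model: torch.nn.Module) -> None:
    # Replace every torch.nn.Linear in-place with a PackedLinear that holds the
    # fp8 weight plus its scale(s), so both end up in the saved state_dict.
    for module in model.modules():
        for child_name, child in list(module.named_children()):
            if isinstance(child, torch.nn.Linear):
                weight_fp8, weight_scale = quantize_blockwise(child.weight.data)
                setattr(module, child_name, PackedLinear(weight_fp8, weight_scale))
```
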

## Update model config

This part is really easy: just add a `quantization_config` to the model's config. This will also appear in the `config.json` file in the huggingface repo of the model.
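
For example (a sketch: the `quant_method: "fp8"` / `activation_scheme` fields mirror what common fp8 checkpoints ship, and the repo id and output path are placeholders):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("your-org/your-model")  # placeholder repo id
config.quantization_config = {
    "quant_method": "fp8",           # how the weights are packed
    "activation_scheme": "dynamic",  # weights are fp8, activation scales computed on the fly
}
config.save_pretrained("path/to/quantized-checkpoint")  # writes config.json
```
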