
Commit de48098

xiaolil1 and tye1 authored
Add tutorial for FP8 (#3537)
* Add tutorial for FP8 --------- Co-authored-by: Ye Ting <[email protected]>
1 parent bf6498f commit de48098

File tree

1 file changed: +44 lines, -0 lines

docs/tutorials/features/float8.md

Lines changed: 44 additions & 0 deletions
Float8 datatype support [GPU] (Experimental)
============================================
## Float8 DataType

Float8 (FP8) is an 8-bit floating-point data type used to reduce memory footprint, improve computation efficiency, and save power in the deep-learning domain.

To meet the value range and precision required by activations, weights, and gradients in a deep neural network (DNN), two formats are used in FP8 training and inference: E4M3 (1 sign bit, 4 exponent bits, 3 mantissa bits) for activations and weights, and E5M2 (1 sign bit, 5 exponent bits, 2 mantissa bits) for gradients. Both formats are defined in [FP8 FORMATS FOR DEEP LEARNING](https://arxiv.org/pdf/2209.05433.pdf).
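
To make the range/precision trade-off concrete, the snippet below prints the numeric limits of the two formats using the FP8 dtypes available in recent stock PyTorch releases (`torch.float8_e4m3fn`, `torch.float8_e5m2`); this is for illustration only and is independent of this extension's FP8 path.

```python
import torch

# Illustration only: recent PyTorch releases expose FP8 dtypes matching
# the E4M3 and E5M2 layouts described above.
for name, dtype in [("E4M3", torch.float8_e4m3fn), ("E5M2", torch.float8_e5m2)]:
    info = torch.finfo(dtype)
    # E4M3 trades range for precision; E5M2 trades precision for range.
    print(f"{name}: max={info.max}, smallest normal={info.smallest_normal}, eps={info.eps}")
```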

At the current stage, the FP8 data type is used for memory storage only; tensors are converted to the BFloat16 data type for computation.
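
A minimal sketch of this storage-only design, again using stock PyTorch dtypes for illustration (the extension performs the equivalent conversion inside its FP8 operators):

```python
import torch

# "FP8 for storage, BFloat16 for compute" pattern.
w = torch.randn(128, 128)              # master weight in FP32
w_fp8 = w.to(torch.float8_e4m3fn)      # stored in 8 bits
x = torch.randn(16, 128, dtype=torch.bfloat16)

y = x @ w_fp8.to(torch.bfloat16)       # up-cast, then compute in BFloat16
```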
## FP8 Quantization

On GPU, online dynamic quantization is used for FP8 data compression and decompression, and the delayed-scaling algorithm is used to accelerate the quantization process.
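
The following is an illustrative sketch of the delayed-scaling idea, not the extension's internal implementation: the quantization scale for the current step is derived from a history of absolute maxima (amax) observed in previous steps, so quantization needs no extra pass over the current tensor.

```python
import torch

E4M3_MAX = 448.0  # largest representable magnitude in the E4M3 format

def delayed_scale(amax_history: torch.Tensor) -> torch.Tensor:
    # The scale comes from past amax observations, not the current tensor.
    return E4M3_MAX / amax_history.max()

amax_history = torch.tensor([2.0, 3.5, 3.1])  # amax of previous steps
t = torch.randn(4, 4) * 3.0                   # tensor to quantize this step

scale = delayed_scale(amax_history)
t_fp8 = (t * scale).to(torch.float8_e4m3fn)   # quantize
t_back = t_fp8.to(torch.float32) / scale      # dequantize

# Record this step's amax for future scale updates.
amax_history = torch.cat([amax_history[1:], t.abs().max().reshape(1)])
```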
## Supported running mode

Both DNN training and inference are supported with the FP8 data type.
## Supported operators

The FP8 Linear operator is supported.
## FP8 usage example

The BERT model is supported as an FP8 training showcase; see the following example:
```python
import torch

from intel_extension_for_pytorch.xpu.fp8.fp8 import fp8_autocast
from intel_extension_for_pytorch.xpu.fp8.recipe import DelayedScaling
from intel_extension_for_pytorch.nn.utils._fp8_convert import convert_fp8_model

# 'model' and the input tensors are assumed to be a BERT model and its
# batch, already created and moved to the "xpu" device.
# AMP can optionally be used together with FP8; 'optimize_dtype' is the
# AMP compute dtype, typically torch.bfloat16.
with torch.xpu.amp.autocast(enabled=True, dtype=optimize_dtype):
    # 'fp8_autocast' is the context manager that enables FP8
    with fp8_autocast(enabled=True, fp8_recipe=DelayedScaling()):
        # Convert the original model to a new model with FP8 operators
        convert_fp8_model(model)
        outputs = model(input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask,
                        labels=masked_lm_labels,
                        next_sentence_label=next_sentence_labels)
```
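
Nesting `fp8_autocast` inside `torch.xpu.amp.autocast` mirrors the design described above: FP8 covers storage while BFloat16 covers computation. `convert_fp8_model(model)` converts the supported modules (currently Linear) to their FP8 counterparts before the forward pass runs.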
