You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
Copy file name to clipboardExpand all lines: docs/tutorials/performance_tuning/known_issues.md
+72-40Lines changed: 72 additions & 40 deletions
Display the source diff
Display the rich diff
Original file line number
Diff line number
Diff line change
@@ -1,13 +1,76 @@
1
1
Known Issues
2
2
============
3
3
4
+
## Usage
5
+
4
6
- There might be Python packages having PyTorch as their hard dependency. If you installed `+cpu` version of PyTorch, installation of these packages might replace the `+cpu` version with the default version released on Pypi.org. If anything goes wrong, please reinstall the `+cpu` version back.
5
7
6
8
- If you found the workload runs with Intel® Extension for PyTorch\* occupies a remarkably large amount of memory, you can try to reduce the occupied memory size by setting the `--weights_prepack` parameter of the `ipex.optimize()` function to `False`.
7
9
8
-
- Supporting of EmbeddingBag with INT8 when bag size > 1 is working in progress.
10
+
- If inference is done with a custom function, `conv+bn` folding feature of the `ipex.optimize()` function doesn't work.
11
+
12
+
```
13
+
import torch
14
+
import intel_pytorch_extension as ipex
15
+
16
+
class Module(torch.nn.Module):
17
+
def __init__(self):
18
+
super(Module, self).__init__()
19
+
self.conv = torch.nn.Conv2d(1, 10, 5, 1)
20
+
self.bn = torch.nn.BatchNorm2d(10)
21
+
self.relu = torch.nn.ReLU()
22
+
23
+
def forward(self, x):
24
+
x = self.conv(x)
25
+
x = self.bn(x)
26
+
x = self.relu(x)
27
+
return x
28
+
29
+
def inference(self, x):
30
+
return self.forward(x)
31
+
32
+
if __name__ == '__main__':
33
+
m = Module()
34
+
m.eval()
35
+
m = ipex.optimize(m, dtype=torch.float32, level="O0")
36
+
d = torch.rand(1, 1, 112, 112)
37
+
with torch.no_grad():
38
+
m.inference(d)
39
+
```
40
+
41
+
This is a PyTorch FX limitation. You can avoid this error by calling `m = ipex.optimize(m, level="O0")`, which doesn't apply ipex optimization, or disable `conv+bn` folding by calling `m = ipex.optimize(m, level="O1", conv_bn_folding=False)`.
42
+
43
+
## TorchDynamo
44
+
45
+
- The support of torch.compile() with ipex as the backend is still an experimental feature. If the workload fails to run or demonstrates poor performance, you can use the `torch.jit` APIs and graph optimization APIs of ipex. Currently, the below HuggingFace models fail to run using torch.compile() with ipex backend due to memory issues:
46
+
- masked-language-modeling+xlm-roberta-base
47
+
- casual-language-modeling+gpt2
48
+
- casual-language-modeling+xlm-roberta-base
49
+
- summarization+t5-base
50
+
- text-classification+allenai-longformer-base-409
51
+
52
+
## Dynamic Shape
53
+
54
+
- When working with an NLP model inference with dynamic input data length appling with TorchScript (either `torch.jit.trace` or `torch.jit.script`), performance with Intel® Extension for PyTorch\* is possible to be less than that without Intel® Extension for PyTorch\*. In this case, adding the workarounds below would help solve this issue.
- Compiling with gcc 11 might result in `illegal instruction` error.
67
+
- Low performance withINT8 support for dynamic shapes
68
+
69
+
The support for dynamic shapes in Intel® Extension for PyTorch\* INT8 integration is still work in progress. When the input shapes are dynamic, for example inputs of variable image sizes in an object detection task or of variable sequence lengths in NLP tasks, the Intel® Extension for PyTorch\* INT8 path may slow down the model inference. In this case, use stock PyTorch INT8 functionality.
70
+
71
+
**Note**: Using Runtime Extension feature if batch size cannot be divided by number of streams, because mini batch size on each stream are not equivalent, scripts run into this issues.
72
+
73
+
- Supporting of EmbeddingBag withINT8 when bag size >1is working in progress.
11
74
12
75
-`RuntimeError: Overflow when unpacking long` when a tensor's min max value exceeds int range while performing int8 calibration. Please customize QConfig to use min-max calibration method.
13
76
@@ -31,55 +94,24 @@ Known Issues
31
94
run_benchmark(freezed_model, input)
32
95
```
33
96
97
+
## BFloat16
98
+
34
99
-BF16 AMP(auto-mixed-precision) runs abnormally with the extension on the AVX2-only machine if the topology contains `Conv`, `Matmul`, `Linear`, and`BatchNormalization`
35
100
101
+
## Runtime Extension
102
+
36
103
- Runtime extension of MultiStreamModule doesn't support DLRM inference, since the input of DLRM (EmbeddingBag specifically) can't be simplely batch split.
37
104
38
105
- Runtime extension of MultiStreamModule has poor performance of RNNT Inference comparing with native throughput mode. Only part of the RNNT models (joint_net specifically) can be jit traced into graph. However, in one batch inference, `joint_net`is invoked multi times. It increases the overhead of MultiStreamModule asinput batch split, thread synchronization and output concat.
39
106
107
+
## Correctness
108
+
40
109
- Incorrect Conv and Linear result if the number of OMP threads is changed at runtime
41
110
42
111
The oneDNN memory layout depends on the number of OMP threads, which requires the caller to detect the changes for the # of OMP threads while this release has not implemented it yet.
43
112
44
-
- Low performance with INT8 support for dynamic shapes
45
-
46
-
The support for dynamic shapes in Intel® Extension for PyTorch\* INT8 integration is still work in progress. When the input shapes are dynamic, for example inputs of variable image sizes in an object detection task or of variable sequence lengths in NLP tasks, the Intel® Extension for PyTorch\* INT8 path may slow down the model inference. In this case, use stock PyTorch INT8 functionality.
47
-
48
-
**Note**: Using Runtime Extension feature if batch size cannot be divided by number of streams, because mini batch size on each stream are not equivalent, scripts run into this issues.
113
+
## Float32 Training
49
114
50
115
- Low throughput withDLRMFP32 Train
51
116
52
117
A 'Sparse Add' [PR](https://github.com/pytorch/pytorch/pull/23057) is pending on review. The issue will be fixed when the PRis merged.
53
-
54
-
- If inference is done with a custom function, `conv+bn` folding feature of the `ipex.optimize()` function doesn't work.
55
-
56
-
```
57
-
import torch
58
-
import intel_pytorch_extension as ipex
59
-
60
-
class Module(torch.nn.Module):
61
-
def __init__(self):
62
-
super(Module, self).__init__()
63
-
self.conv = torch.nn.Conv2d(1, 10, 5, 1)
64
-
self.bn = torch.nn.BatchNorm2d(10)
65
-
self.relu = torch.nn.ReLU()
66
-
67
-
def forward(self, x):
68
-
x = self.conv(x)
69
-
x = self.bn(x)
70
-
x = self.relu(x)
71
-
return x
72
-
73
-
def inference(self, x):
74
-
return self.forward(x)
75
-
76
-
if __name__ == '__main__':
77
-
m = Module()
78
-
m.eval()
79
-
m = ipex.optimize(m, dtype=torch.float32, level="O0")
80
-
d = torch.rand(1, 1, 112, 112)
81
-
with torch.no_grad():
82
-
m.inference(d)
83
-
```
84
-
85
-
This is a PyTorch FX limitation. You can avoid this error by calling `m = ipex.optimize(m, level="O0")`, which doesn't apply ipex optimization, or disable `conv+bn` folding by calling `m = ipex.optimize(m, level="O1", conv_bn_folding=False)`.
0 commit comments