Commit 35ab427

[DOC] Add docs for TIMM and TorchBench setup (#125)
* Add instructions for TIMM and TorchBench
* Update end_to_end_tests.md
* Update
1 parent da1bc1f commit 35ab427

docs/test_docs/end_to_end_tests.md

Lines changed: 155 additions & 45 deletions

- [Overview](#overview)
- [Prerequisites](#prerequisites)
- [Package Installation](#package-installation)
  - [HuggingFace and TIMM Models Installation](#huggingface-and-timm-models-installation)
  - [TorchBench Installation](#torchbench-installation)
    - [Install Torch Vision](#install-torch-vision)
    - [Install Torch Text](#install-torch-text)
    - [Install Torch Audio](#install-torch-audio)
    - [Install TorchBenchmark](#install-torchbenchmark)
- [Run the Model](#run-the-model)
  - [Command Details](#command-details)
  - [Debugging Tips](#debugging-tips)
- [Profiling](#profiling)
  - [Option 1: Use Legacy Profiling](#option-1-use-legacy-profiling)
    - [Profiling Settings](#profiling-settings)
  - [Option 2: Use Kineto Profiling](#option-2-use-kineto-profiling)
    - [Profiling Settings](#profiling-settings-1)
  - [End-to-end Tests Setting](#end-to-end-tests-setting)
    - [Profiling Tips](#profiling-tips)

# Overview

This document describes the [Torchdynamo Benchmarks](https://github.com/pytorch/pytorch/tree/main/benchmarks/dynamo) setup for the XPU Backend for Triton\*. The benchmark collection bundles several model suites behind a common frontend; the examples below cover the [Hugging Face\*](https://huggingface.co/), [TIMM Models](https://github.com/rwightman/pytorch-image-models), and [TorchBench](https://github.com/pytorch/benchmark) End-to-End models.

# Prerequisites

The PyTorch version should match the one in the [installation guide for intel_extension_for_pytorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html#installation-guide).

# Package Installation

## HuggingFace and TIMM Models Installation

The scripts in [Torchdynamo Benchmarks](https://github.com/pytorch/pytorch/tree/main/benchmarks/dynamo) automatically download and install the `transformers` and `timm` packages. However, the setup may uninstall the XPU build of PyTorch and install the CUDA build instead, so verifying the PyTorch build before running is crucial:

```Bash
# Wrong one: the environment picked up the CUDA build
(triton_env) ➜ python
>>> import torch
>>> torch.__version__
'2.1.0+cu121'
>>> torch.__file__
'/home/user/miniconda3/envs/triton_env/lib/python3.10/site-packages/torch/__init__.py'

# Correct one: the XPU build
>>> import torch
>>> torch.__version__
'2.1.0a0+gitdd9913f'
>>> torch.__file__
'/home/user/pytorch/torch/__init__.py'
```

If the PyTorch version is incorrect, please reinstall the [XPU version of PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/tutorials/installation.html#installation-guide).
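
For a quick non-interactive check of the same thing, a one-liner such as the following can be used (it only prints the version string and the install location):

```Bash
python -c "import torch; print(torch.__version__, torch.__file__)"
```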

## TorchBench Installation

TorchBench depends on [torchvision](https://github.com/pytorch/vision.git), [torchtext](https://github.com/pytorch/text) and [torchaudio](https://github.com/pytorch/audio.git). By default these packages are built with CUDA support, so for XPU support all of them need to be **built from source**.

Use the following commands to build and install the dependencies:

### Install Torch Vision

```Bash
git clone --recursive https://github.com/pytorch/vision.git
cd vision
conda install libpng jpeg
conda install -c conda-forge ffmpeg
python setup.py install
```

### Install Torch Text

```Bash
git clone --recursive https://github.com/pytorch/text
cd text
python setup.py clean install
```

Note that the build may report the following error; it can be ignored:

```Bash
Processing dependencies for torchtext==0.17.0a0+c0d0685
error: torch 2.1.0a0+gitdd9913f is installed but torch==2.1.0 is required by {'torchdata'}
```

### Install Torch Audio

```Bash
pip install torchaudio
git clone --recursive https://github.com/pytorch/audio.git
cd audio
python setup.py install
```

### Install TorchBenchmark

Ensure all dependencies are correctly installed:

```Bash
python -c "import torchvision,torchtext,torchaudio;print(torchvision.__version__, torchtext.__version__, torchaudio.__version__)"
```

Then install TorchBenchmark as a library:

```Bash
conda install git-lfs pyyaml pandas scipy psutil
git clone --recursive https://github.com/pytorch/benchmark.git

cd benchmark
python install.py
pip install .
```
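
After `pip install .`, a quick sanity check that the library is importable (assuming the installed package is named `torchbenchmark`, which is how the repo exposes it):

```Bash
python -c "import torchbenchmark; print(torchbenchmark.__file__)"
```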

# Run the Model

Simply run the model using the following shell script. Note that there are some tricks for debugging; it is recommended to refer to [Debugging Tips](#debugging-tips).

Copy the shell script [intel_xpu_backend/.github/scripts/inductor_xpu_test.sh](../../.github/scripts/inductor_xpu_test.sh) to the PyTorch source folder, then execute the command:

```Bash
# Run all models
bash xpu_run_batch.sh huggingface amp_bf16 training performance xpu 0
# Run a single model (e.g. T5Small)
bash xpu_run_batch.sh huggingface amp_bf16 training performance xpu 0 static 1 0 T5Small
```

For a real example, refer to our CI command in [triton_xpu_backend_e2e_nightly.yml](https://github.com/intel/intel-xpu-backend-for-triton/blob/da1bc1fb7a39cb3c3332a92fba47c2fc1df25396/.github/workflows/triton_xpu_backend_e2e_nightly.yml#L230-L233).

Environment variables useful for debugging include:
- `TORCHINDUCTOR_CACHE_DIR={some_DIR}`: Specifies the cache directory. Useful for debugging.
- `TORCH_COMPILE_DEBUG=1`: Enables debug information printing.
- `TRITON_XPU_PROFILE=ON`: Displays XPU Triton kernels for debugging.

By default, the cache dir is under `/tmp/torchinductor_{user}/`; it is advisable to change this when debugging, as demonstrated below:

```Bash
LOG_DIR=${WORKSPACE}/inductor_log/${SUITE}/${MODEL}/${DT}
export TORCHINDUCTOR_CACHE_DIR=${LOG_DIR}
```

## Command Details

For fine-grained control, the typical command structure is as follows:

```Bash
python benchmarks/dynamo/${SUITE}.py --only ${MODEL} --accuracy --amp -dxpu -n50 --no-skip --dashboard ${Mode_extra} --backend=inductor --timeout=4800 --output=${LOG_DIR}/${LOG_NAME}.csv
```

The full argument list is accessible via:

```Bash
python benchmarks/dynamo/huggingface.py --help
```

Additional configuration settings are available in Python code, specifically in [torch._dynamo.config](https://github.com/pytorch/pytorch/blob/main/torch/_dynamo/config.py) and [torch._inductor.config](https://github.com/pytorch/pytorch/blob/main/torch/_inductor/config.py). One example of using these configs is in [Debugging Tips](#debugging-tips); set them according to your needs.
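
For instance, a minimal sketch of setting such knobs at the top of a benchmark script (the specific options shown here are only illustrative, not required settings):

```Python
import torch._dynamo
import torch._inductor.config

# surface compilation errors instead of silently falling back to eager mode
torch._dynamo.config.suppress_errors = False
# give Triton kernels descriptive names (also used later in this doc)
torch._inductor.config.triton.unique_kernel_names = True
```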

## Debugging Tips

It is recommended to set the following environment variables for debugging (a combined example is shown after the list):

- `TORCHINDUCTOR_CACHE_DIR={some-dir}`: Designates the TorchInductor cache location.
- `TRITON_CACHE_DIR={some-dir}`: Specifies the Triton cache directory; by default it is the `TORCHINDUCTOR_CACHE_DIR/triton` folder.
- `TORCH_COMPILE_DEBUG_DIR={some-dir}`: Where the compile debug files are put. You will see folders like `aot_torchinductor`, containing the TorchInductor logs, and `torchdynamo`, containing the Dynamo log.
- `TORCH_COMPILE_DEBUG=1`: Enables detailed TorchInductor tracing. It prints a lot of messages, so it is recommended to redirect the output to a file. With this flag set, the reproducible Python file can be found easily.
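
For example (a sketch; the directory names are placeholders, and the benchmark command follows the same style as above):

```Bash
export TORCHINDUCTOR_CACHE_DIR=$PWD/inductor_cache
export TRITON_CACHE_DIR=$PWD/triton_cache
export TORCH_COMPILE_DEBUG_DIR=$PWD/compile_debug

# redirect the very verbose debug output to a file
TORCH_COMPILE_DEBUG=1 python benchmarks/dynamo/huggingface.py --only T5Small --accuracy --amp -dxpu --backend=inductor &> debug.log
```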

Alternatively, the above env flags can also be set in a Python file, like below:

```Python
# helps to generate descriptive kernel names
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.kernel_name_max_ops = 8
```

**Reproducing Errors with a Smaller Python File**

For efficiency, reproduce errors using a smaller Python file. Enable `TORCH_COMPILE_DEBUG=1` to generate detailed outputs, which can be redirected to a file for easier inspection. The debug folder will contain files like `fx_graph_readable.py`, `fx_graph_runnable.py`, and `output_code.py`, which can be used for further analysis and debugging.

Since the output is very large, redirect it to a file:

```Bash
TORCH_COMPILE_DEBUG=1 python ... &> test.log
```
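
The TorchDynamo minifier can also shrink a failing run into a small repro script; a minimal sketch, assuming the standard `repro_after` knob in `torch._dynamo.config`:

```Python
import torch._dynamo

# dump a minified repro script after a TorchDynamo-level failure
torch._dynamo.config.repro_after = "dynamo"
```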
# Profiling

To profile the result, use the `performance` mode instead of `accuracy`, and make sure the profiler trace flag `--export-profiler-trace` is enabled in `inductor_xpu_test.sh`, i.e.:

```Bash
python benchmarks/dynamo/${SUITE}.py ... --performance --export-profiler-trace ...
```

## Option 1: Use Legacy Profiling

For now, we use [profiler_legacy](https://github.com/intel/intel-extension-for-pytorch/blob/xpu-master/docs/tutorials/features/profiler_legacy.md) to capture the profiling result. We are migrating from legacy profiling to Kineto profiling; since legacy profiling is currently more stable, it is recommended to use it first.

A typical profiling snippet looks like the following (only the core lines are shown; run the workload to be profiled inside the `with` block):

```Python
with torch.autograd.profiler_legacy.profile(use_xpu=True) as prof:
    # run the model or op to be profiled here
    ...

# print the result table formatted by the legacy profiler tool as your wish
print(prof.key_averages().table(sort_by="self_xpu_time_total"))
```

### Profiling Settings

For E2E tests, there are several places to change. `cd` to `pytorch/benchmarks/dynamo` and change `common.py` as below. Note that the line numbers may not be the same, but the places to change are unique. The hunk below is only a sketch of the idea (swap the default profiler for the legacy one inside `maybe_profile`); adapt it to the code you actually see:

```diff
 def maybe_profile(*args, **kwargs):
     if kwargs.pop("enabled", True):
-        with torch.profiler.profile(*args, **kwargs) as p:
+        with torch.autograd.profiler_legacy.profile(use_xpu=True) as p:
             yield p
     else:
         yield
```

## Option 2: Use Kineto Profiling

We are migrating to Kineto profiling; in the future, it will be the only option. For now, be sure to enable the environment variable `export IPEX_ZE_TRACING=1`. A typical profiling case looks like below:

```Python
import torch
import intel_extension_for_pytorch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(3).xpu()
b = torch.randn(3).xpu()

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU]) as prof:
    c = a + b

print(prof.key_averages().table())
```

### Profiling Settings

Same as with legacy profiling, you can modify the code like this:

```diff
@@ -530,7 +536,7 @@ def speedup_experiment(args, model_iter_fn, model, example_inputs, **kwargs):
 @contextlib.contextmanager
 def maybe_profile(*args, **kwargs):
     if kwargs.pop("enabled", True):
-        with torch.profiler.profile(*args, **kwargs) as p:
+        with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.XPU], *args, **kwargs) as p:
             yield p
     else:
         yield
```
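
Note that this change assumes `ProfilerActivity` is in scope in `common.py`; if it is not, an import such as `from torch.profiler import ProfilerActivity` would also be needed (an assumption; check the actual file).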

### End-to-end Tests Setting

#### Profiling Tips

To run the model, add the `--export-profiler-trace` flag. Because the profiling process links libtorch, it is highly recommended to **run twice**: the second run greatly reduces the kernel compiling time and returns results much faster, as sketched below.
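
For example, using the same wrapper script as above (a sketch; adjust the suite and model arguments to your case):

```Bash
# first run: compiles the kernels and warms the caches
bash xpu_run_batch.sh huggingface amp_bf16 training performance xpu 0 static 1 0 T5Small
# second run: the profiled run, reusing the already compiled kernels
bash xpu_run_batch.sh huggingface amp_bf16 training performance xpu 0 static 1 0 T5Small
```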

If you wish to make the kernel names more readable, you can enable the following:

```Python
# common.py
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.kernel_name_max_ops = 8
```