# Intel® Extension for PyTorch\* Large Language Model (LLM) Feature Get Started For Qwen2 models

Intel® Extension for PyTorch\* provides dedicated optimizations for running Qwen2 models faster, covering technical points such as paged attention and ROPE fusion. A set of data types is supported for various scenarios, including BF16 and weight-only quantization.
# 1. Environment Setup

Two environment setup methods are provided. You can choose either of them according to your usage scenario. The Docker-based method is recommended.

## 1.1 [RECOMMENDED] Docker-based environment setup with pre-built wheels

```bash
# Get the Intel® Extension for PyTorch* source code
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout 2.3-qwen-2
git submodule sync
git submodule update --init --recursive

# Build an image with the provided Dockerfile by installing from Intel® Extension for PyTorch* prebuilt wheel files
DOCKER_BUILDKIT=1 docker build -f examples/cpu/inference/python/llm/Dockerfile -t ipex-llm:qwen2 .

# Run the container with the command below
docker run --rm -it --privileged ipex-llm:qwen2 bash

# When the command prompt shows inside the docker container, enter the llm examples directory
cd llm

# Activate environment variables
source ./tools/env_activate.sh
```
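
If the Qwen2 model files are already downloaded on the host, the `docker run` step above can optionally be extended with a bind mount so the container can reuse them. This is only a sketch: the host path and the mount point inside the container are placeholders, not paths defined by this guide.

```bash
# Optional: expose a host directory with downloaded model files inside the container (paths are examples)
docker run --rm -it --privileged -v $HOME/qwen2_models:/home/ubuntu/llm/qwen2_models ipex-llm:qwen2 bash
```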

## 1.2 Conda-based environment setup with pre-built wheels

```bash
# Get the Intel® Extension for PyTorch* source code
git clone https://github.com/intel/intel-extension-for-pytorch.git
cd intel-extension-for-pytorch
git checkout 2.3-qwen-2
git submodule sync
git submodule update --init --recursive

# Create a conda environment (pre-built wheels are only available with python=3.10)
conda create -n llm python=3.10 -y
conda activate llm

# Set up the environment with the provided script
# A sample "prompt.json" file for benchmarking is also downloaded
cd examples/cpu/inference/python/llm
bash ./tools/env_setup.sh 7

# Activate environment variables
source ./tools/env_activate.sh
```
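
As an optional sanity check, you can verify that the extension imports correctly in the activated environment before moving on:

```bash
# Optional: confirm the Intel® Extension for PyTorch* installation and print its version
python -c "import torch, intel_extension_for_pytorch as ipex; print(ipex.__version__)"
```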
<br>

# 2. How To Run Qwen2 with ipex.llm

**ipex.llm provides a single script to facilitate running generation tasks as below:**

```bash
# if you are using a docker container built from commands above in Sec. 1.1, the placeholder LLM_DIR below is /home/ubuntu/llm
# if you are using a conda env created with commands above in Sec. 1.2, the placeholder LLM_DIR below is intel-extension-for-pytorch/examples/cpu/inference/python/llm
cd <LLM_DIR>
python run.py --help # for more detailed usages
```

| Key args of run.py | Notes |
|---|---|
| model id | use `--model-name-or-path` or `-m` to specify the <QWEN2_MODEL_ID_OR_LOCAL_PATH>, which is either a model id from Huggingface or a local path to downloaded model files |
| generation | default: beam search (beam size = 4); use `--greedy` for greedy search |
| input tokens | use `--input-tokens` to set a fixed input prompt size, with <INPUT_LENGTH> in [1024, 2048, 4096, 8192, 16384, 32768]; if `--input-tokens` is not used, use `--prompt` to provide another string as the prompt input |
| output tokens | default: 32; use `--max-new-tokens` to choose any other size |
| batch size | default: 1; use `--batch-size` to choose any other size |
| token latency | enable `--token-latency` to print out the first-token and next-token latencies |
| generation iterations | use `--num-iter` and `--num-warmup` to control the repeated generation iterations; default: 100 iterations with 10 warmup iterations |
| streaming mode output | greedy search only (works with `--greedy`); use `--streaming` to enable streaming generation output |
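
For example, a hypothetical single-instance run combining several of the arguments above (the model id `Qwen/Qwen2-7B-Instruct` and the chosen values are illustrative only):

```bash
python run.py --benchmark -m Qwen/Qwen2-7B-Instruct --dtype bfloat16 --ipex --greedy --input-tokens 1024 --max-new-tokens 128 --token-latency
```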

*Note:* You may need to log in to your HuggingFace account to access the model files. Please refer to [HuggingFace login](https://huggingface.co/docs/huggingface_hub/quick-start#login).
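
A minimal sketch of logging in from the command line, assuming the `huggingface_hub` CLI is installed in your environment:

```bash
# Log in interactively with a Hugging Face access token
huggingface-cli login
```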

## 2.1 Usage of running Qwen2 models

The *<QWEN2_MODEL_ID_OR_LOCAL_PATH>* in the commands below specifies the Qwen2 model you will run, which can be found on [HuggingFace Models](https://huggingface.co/models).

### 2.1.1 Run generation with multiple instances on multiple CPU NUMA nodes

#### 2.1.1.1 Prepare:

```bash
unset KMP_AFFINITY
```

In the DeepSpeed cases below, we recommend `--shard-model` to shard the model weights more evenly for better memory usage when running with DeepSpeed.

If `--shard-model` is used, a copy of the sharded model weight files is saved to the path given by `--output-dir` (the default path is `./saved_results` if not provided).
If you have used `--shard-model` and generated such a sharded model path (or your model weight files are already well sharded), please remove `--shard-model` in further repeated benchmarks and replace `-m <QWEN2_MODEL_ID_OR_LOCAL_PATH>` with `-m <shard model path>` to skip the repeated sharding steps.

Besides, a standalone model sharding script is also provided in Section 2.1.1.4, in case you would like to generate the sharded model weight files in advance of running distributed inference.

#### 2.1.1.2 BF16:

- Command:
```bash
deepspeed --bind_cores_to_rank run.py --benchmark -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH> --autotp --shard-model
```

#### 2.1.1.3 Weight-only quantization (INT8):

By default, weight-only quantization runs inference with [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) (`--quant-with-amp`) to get peak performance and fair accuracy.
For weight-only quantization with DeepSpeed, we quantize the model and then run the benchmark. The quantized model is not saved.

- Command:
```bash
deepspeed --bind_cores_to_rank run.py --benchmark -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --ipex --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --greedy --input-tokens <INPUT_LENGTH> --autotp --shard-model
```

#### 2.1.1.4 How to Shard Model weight files for Distributed Inference with DeepSpeed

To reduce memory usage, we can shard the model weight files under a local path before launching distributed tests with DeepSpeed.

```bash
cd ./utils
# general command:
python create_shard_model.py -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --save-path ./local_qwen2_model_shard
# After sharding the model, use "-m ./local_qwen2_model_shard" in later tests
```
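
After sharding, the distributed commands from Sections 2.1.1.2 and 2.1.1.3 can point directly at the sharded copy and drop `--shard-model`; for example, adapting the BF16 command above:

```bash
# Example: BF16 distributed run reusing the pre-sharded weights (no --shard-model needed)
deepspeed --bind_cores_to_rank run.py --benchmark -m ./local_qwen2_model_shard --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH> --autotp
```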

### 2.1.2 Run generation with single instance on a single NUMA node
#### 2.1.2.1 BF16:

- Command:
```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --dtype bfloat16 --ipex --greedy --input-tokens <INPUT_LENGTH>
```

#### 2.1.2.2 Weight-only quantization (INT8):

By default, weight-only quantization runs inference with [Automatic Mixed Precision](https://pytorch.org/tutorials/recipes/recipes/amp_recipe.html) (`--quant-with-amp`) to get peak performance and fair accuracy.

- Command:
```bash
OMP_NUM_THREADS=<physical cores num> numactl -m <node N> -C <physical cores list> python run.py --benchmark -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --output-dir "saved_results" --greedy --input-tokens <INPUT_LENGTH>
```

#### 2.1.2.3 Notes:

(1) [`numactl`](https://linux.die.net/man/8/numactl) is used to specify the memory and cores of your hardware for better performance. *<node N>* specifies the [NUMA](https://en.wikipedia.org/wiki/Non-uniform_memory_access) node id (e.g., 0 to use the memory from the first NUMA node). *<physical cores list>* specifies the physical cores you are using from the *<node N>* NUMA node. You can use the [`lscpu`](https://man7.org/linux/man-pages/man1/lscpu.1.html) command on Linux to check the NUMA node information.

(2) For all quantization benchmarks, both the quantization and inference stages are triggered by default. The quantization stage auto-generates the quantized model named `best_model.pt` in the `--output-dir` path, and the inference stage launches inference with that quantized `best_model.pt`. For inference-only benchmarks (to avoid repeating the quantization stage), you can reuse these quantized models by adding `--quantized-model-path <output_dir + "best_model.pt">`.
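
Putting notes (1) and (2) together, an inference-only weight-only quantization run on a hypothetical machine with 56 physical cores on NUMA node 0 might look like the following (the core counts and paths are illustrative only):

```bash
# Reuse a previously generated best_model.pt instead of quantizing again (core counts and paths are examples)
OMP_NUM_THREADS=56 numactl -m 0 -C 0-55 python run.py --benchmark -m <QWEN2_MODEL_ID_OR_LOCAL_PATH> --ipex-weight-only-quantization --weight-dtype INT8 --quant-with-amp --greedy --input-tokens <INPUT_LENGTH> --quantized-model-path "saved_results/best_model.pt"
```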

## Miscellaneous Tips
Intel® Extension for PyTorch\* also provides dedicated optimizations for many other Large Language Models (LLMs), covering a set of data types supported for various scenarios. For more details, please check this [Intel® Extension for PyTorch\* doc](https://github.com/intel/intel-extension-for-pytorch/blob/release/2.3/README.md).