Commit fdf4a93

enable IPEX optimization
Signed-off-by: kta-intel <[email protected]>
1 parent 2470e7d commit fdf4a93

5 files changed: +151 -0 lines changed

Dockerfile

Lines changed: 14 additions & 0 deletions
@@ -2,8 +2,10 @@
 ARG BASE_UBI_IMAGE_TAG=9.3-1361.1699548029
 ARG PROTOC_VERSION=25.0
 ARG PYTORCH_INDEX="https://download.pytorch.org/whl"
+ARG IPEX_INDEX="https://pytorch-extension.intel.com/release-whl/stable/cpu/us/"
 #ARG PYTORCH_INDEX="https://download.pytorch.org/whl/nightly"
 ARG PYTORCH_VERSION=2.1.0
+ARG IPEX_VERSION=2.1.0

 ## Base Layer ##################################################################
 FROM registry.access.redhat.com/ubi9/ubi:${BASE_UBI_IMAGE_TAG} as base

@@ -148,6 +150,7 @@ WORKDIR /usr/src

 # Install specific version of torch
 RUN pip install torch=="$PYTORCH_VERSION+cpu" --index-url "${PYTORCH_INDEX}/cpu" --no-cache-dir
+RUN pip install intel-extension-for-pytorch=="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir

 COPY server/Makefile server/Makefile

@@ -174,6 +177,8 @@ RUN cd integration_tests && make install
 FROM cuda-devel as python-builder
 ARG PYTORCH_INDEX
 ARG PYTORCH_VERSION
+ARG IPEX_INDEX
+ARG IPEX_VERSION

 RUN dnf install -y unzip git ninja-build && dnf clean all

@@ -187,6 +192,7 @@ ENV PATH=/opt/miniconda/bin:$PATH
 # Install specific version of torch
 RUN pip install ninja==1.11.1.1 --no-cache-dir
 RUN pip install torch==$PYTORCH_VERSION+cu118 --index-url "${PYTORCH_INDEX}/cu118" --no-cache-dir
+RUN pip install intel-extension-for-pytorch~="$IPEX_VERSION" --extra-index-url "${IPEX_INDEX}" --no-cache-dir

 ## Build flash attention v2 ####################################################

@@ -241,6 +247,14 @@ COPY --from=flash-att-builder /usr/src/flash-attention/csrc/rotary/build /usr/sr
 FROM base as flash-att-v2-cache
 COPY --from=flash-att-v2-builder /usr/src/flash-attention-v2/build /usr/src/flash-attention-v2/build

+## Setup environment variables for performance on Xeon
+ENV KMP_BLOCKTIME=INF
+ENV KMP_TPAUSE=0
+ENV KMP_SETTINGS=1
+ENV KMP_AFFINITY=granularity=fine,compact,1,0
+ENV KMP_FORJOIN_BARRIER_PATTERN=dist,dist
+ENV KMP_PLAIN_BARRIER_PATTERN=dist,dist
+ENV KMP_REDUCTION_BARRIER_PATTERN=dist,dist

 ## Final Inference Server image ################################################
 FROM cuda-runtime as server-release
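The added `pip install` lines pull the Intel Extension for PyTorch (IPEX) wheel in alongside the pinned PyTorch version, and the `KMP_*` variables tune the Intel OpenMP runtime for Xeon. A hedged sanity-check sketch, runnable inside the built image (not part of this commit), to confirm the versions line up and the tuning variables are visible:

```python
# Hedged sanity check: confirm torch / IPEX versions match the Dockerfile ARGs
# and that the Xeon tuning variables are set in the environment.
import os
import torch
import intel_extension_for_pytorch as ipex

print("torch:", torch.__version__)   # expected to start with 2.1.0
print("ipex:", ipex.__version__)     # expected to start with 2.1
for var in ("KMP_BLOCKTIME", "KMP_AFFINITY", "KMP_SETTINGS"):
    print(var, "=", os.environ.get(var))
```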

README.md

Lines changed: 50 additions & 0 deletions
@@ -158,3 +158,53 @@ They are all prefixed with `tgi_`. Descriptions will be added to the table below
| `tgi_tokenize_request_input_count` | `counter` | | |
| `tgi_tokenize_request_tokens` | `histogram` | | |
| `tgi_tokenize_request_duration` | `histogram` | | |

### Run Inference Locally with Intel(R) Extension for PyTorch*

#### 0. Build the image

```
make build
```

This command will print the Docker image id for `text-gen-server`. Set `IMAGE_ID` in the commands below to that image id.

#### 1. Run the server

```
export IMAGE_ID=<image_id>
export MODEL=<model>
export volume=$PWD/data
mkdir $volume
chmod 777 $volume
```

It's possible to use `text-generation-server download-weights`, but in this example we use a model that we download locally with `transformers-cli`:

```
transformers-cli download $MODEL
```

Move the model from `~/.cache/huggingface/hub/` to `$volume`. You can then run the inference server with:

```
docker run -p 8033:8033 -p 3000:3000 -e TRANSFORMERS_CACHE=/data -e HUGGINGFACE_HUB_CACHE=/data -e DEPLOYMENT_FRAMEWORK=hf_transformers_ipex -e MODEL_NAME=$MODEL -v $volume:/data $IMAGE_ID text-generation-launcher --dtype-str bfloat16
```

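As an alternative to the `transformers-cli` download plus manual move in step 1, the snapshot can be fetched straight into the mounted volume. This is a hedged sketch (it assumes `huggingface_hub` is available in the local environment and that the `MODEL` and `volume` variables above are exported); it reproduces the hub cache layout the server resolves under `/data`:

```python
# Hedged alternative to "transformers-cli download" + manual move: pull the
# model snapshot directly into $volume using huggingface_hub.
import os
from huggingface_hub import snapshot_download

model = os.environ["MODEL"]                  # e.g. a causal LM repo id
volume = os.environ.get("volume", "./data")  # the directory mounted at /data

# cache_dir keeps the hub layout (models--<org>--<name>) that the server
# expects when HUGGINGFACE_HUB_CACHE=/data.
snapshot_download(repo_id=model, cache_dir=volume)
```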
#### 2. Prepare the client

Install gRPC in a Python environment: `pip install grpcio grpcio-tools`

In the repository root, run:

```
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generate.proto
python -m grpc_tools.protoc -Iproto --python_out=pb --pyi_out=pb --grpc_python_out=pb proto/generation.proto
```

This generates the necessary files in the `pb` directory.

Then to run inference:

```
python pb/client.py
```

Edit `pb/client.py` to change the prompts.

pb/client.py

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
import json
import time

import grpc
import requests
from google.protobuf import json_format

import generation_pb2 as pb2
import generation_pb2_grpc as gpb2

port = 8033
channel = grpc.insecure_channel(f"localhost:{port}")
stub = gpb2.GenerationServiceStub(channel)

# warmup inference
for i in range(5):
    text = "hello world"
    message = json_format.ParseDict(
        {"requests": [{"text": text}]}, pb2.BatchedGenerationRequest()
    )
    response = stub.Generate(message)

# time inference
for prompt in ["The weather is", "The cat is walking on", "I would like to"]:
# for prompt in ["def hello_world():"]:
    message = json_format.ParseDict(
        {"requests": [{"text": prompt}]}, pb2.BatchedGenerationRequest()
    )
    start = time.perf_counter()
    response = stub.Generate(message)
    end = time.perf_counter()
    print(prompt, response)
    print(f"Duration: {end-start:.2f}")
server/text_generation_server/inference_engine/hf_transformers_ipex.py (file name not captured in this view; inferred from the `DEPLOYMENT_FRAMEWORK=hf_transformers_ipex` setting above)

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
import os
import torch
import intel_extension_for_pytorch as ipex
from transformers.models.auto.auto_factory import _BaseAutoModelClass

from text_generation_server.inference_engine.engine import BaseInferenceEngine
from text_generation_server.utils.hub import TRUST_REMOTE_CODE
from typing import Any, Optional


class InferenceEngine(BaseInferenceEngine):
    def __init__(
        self,
        model_path: str,
        model_class: type[_BaseAutoModelClass],
        dtype: torch.dtype,
        model_config: Optional[Any]
    ) -> None:
        super().__init__(model_path, model_config)

        kwargs = {
            "pretrained_model_name_or_path": model_path,
            "local_files_only": True,
            "trust_remote_code": TRUST_REMOTE_CODE,
            "torchscript": 'jit',
            "torch_dtype": dtype
        }

        if model_config.model_type == "mpt":
            model_config.init_device = str(self.device)
            kwargs["config"] = model_config

        try:
            ipex._C.disable_jit_linear_repack()
        except Exception:
            pass

        torch._C._jit_set_texpr_fuser_enabled(False)

        slow_but_exact = os.getenv('BLOOM_SLOW_BUT_EXACT', 'false').lower() == 'true'
        if slow_but_exact:
            kwargs["slow_but_exact"] = True

        with self.device:
            self.model = model_class.from_pretrained(**kwargs).requires_grad_(False).eval()

        self.model = self.model.to(memory_format=torch.channels_last)
        self.model = ipex.optimize_transformers(self.model, dtype=dtype, inplace=True)
        print('Intel(R) Extension for PyTorch* enabled')

        self.model.to(self.device)
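The engine above wraps a standard `transformers` causal LM with IPEX's transformer optimizations. A minimal standalone sketch of the same pattern follows; the model name and `max_new_tokens` value are placeholders for illustration, not taken from this repo:

```python
# Hedged sketch of the optimization path used by the engine above: load a
# causal LM in bfloat16, switch to channels-last memory format, then apply
# ipex.optimize_transformers (the IPEX 2.1 API this commit relies on).
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "facebook/opt-125m"  # placeholder model for illustration
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = (
    AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
    .requires_grad_(False)
    .eval()
)

model = model.to(memory_format=torch.channels_last)
model = ipex.optimize_transformers(model, dtype=torch.bfloat16, inplace=True)

inputs = tokenizer("The weather is", return_tensors="pt")
with torch.inference_mode():
    output = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(output[0], skip_special_tokens=True))
```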

server/text_generation_server/models/causal_lm.py

Lines changed: 3 additions & 0 deletions
@@ -575,6 +575,9 @@ def __init__(
 _, past_key_values, _ = self.forward(input_ids=one_token, attention_mask=one_token)
 if torch.is_tensor(past_key_values[0]):
     self.batch_type = CombinedKVCausalLMBatch
+elif 'ipex' in deployment_framework:
+    print(deployment_framework)
+    self.batch_type = CausalLMBatch
 else:
     # check the ordering of the key tensor dimensions
     key_past, value_past = past_key_values[0]
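The new branch routes any deployment framework containing `ipex` to the plain `CausalLMBatch`, which expects the standard Hugging Face cache layout: one `(key, value)` tensor pair per layer. A small hedged illustration of that layout (the tiny placeholder model is an assumption for demonstration, not from this repo):

```python
# Hedged illustration: for a standard HF causal LM, past_key_values holds one
# (key, value) pair per layer, so past_key_values[0] is not itself a tensor
# and the CausalLMBatch path above applies.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "sshleifer/tiny-gpt2"  # tiny placeholder model for a quick check
model = AutoModelForCausalLM.from_pretrained(name).eval()
tok = AutoTokenizer.from_pretrained(name)

with torch.inference_mode():
    out = model(**tok("hello", return_tensors="pt"), use_cache=True)

print(torch.is_tensor(out.past_key_values[0]))  # False: it is a (key, value) pair
print(out.past_key_values[0][0].shape)          # key tensor: [batch, heads, seq, head_dim]
```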
