Commit d3a8098

Authored by yuanwu2017, kaixuanliu, and pi314ever
Enable intel devices CPU/XPU/HPU for python backend (#245)
Signed-off-by: yuanwu <[email protected]>
Signed-off-by: Liu, Kaixuan <[email protected]>
Signed-off-by: kaixuanliu <[email protected]>
Co-authored-by: Liu, Kaixuan <[email protected]>
Co-authored-by: Daniel Huang <[email protected]>
1 parent 57d8fc8 commit d3a8098

File tree

17 files changed: +723 additions, -101 deletions


Cargo.lock

Lines changed: 1 addition & 0 deletions
Generated lockfile; diff not rendered by default.

Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -30,6 +30,7 @@ tracing = "0.1"
 serde = { version = "1.0", features = ["serde_derive"] }
 serde_json = "1.0"
 thiserror = "1.0"
+rand = "0.8"
 
 
 [patch.crates-io]

Dockerfile-intel

Lines changed: 146 additions & 0 deletions
@@ -0,0 +1,146 @@
+ARG PLATFORM=cpu
+FROM lukemathwalker/cargo-chef:latest-rust-1.75-bookworm AS chef
+WORKDIR /usr/src
+ENV SCCACHE=0.5.4
+ENV RUSTC_WRAPPER=/usr/local/bin/sccache
+
+# Download and configure sccache
+RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
+    chmod +x /usr/local/bin/sccache
+
+FROM chef AS planner
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN cargo chef prepare --recipe-path recipe.json
+
+FROM chef AS builder
+
+ARG GIT_SHA
+ARG DOCKER_LABEL
+
+# sccache specific variables
+ARG ACTIONS_CACHE_URL
+ARG ACTIONS_RUNTIME_TOKEN
+ARG SCCACHE_GHA_ENABLED
+
+COPY --from=planner /usr/src/recipe.json recipe.json
+
+RUN cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s
+
+COPY backends backends
+COPY core core
+COPY router router
+COPY Cargo.toml ./
+COPY Cargo.lock ./
+
+RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
+    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
+    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
+    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
+    rm -f $PROTOC_ZIP
+
+FROM builder as http-builder
+
+RUN cargo build --release --bin text-embeddings-router -F python -F http --no-default-features && sccache -s
+
+FROM builder as grpc-builder
+
+COPY proto proto
+
+RUN cargo build --release --bin text-embeddings-router -F grpc -F python --no-default-features && sccache -s
+
+FROM intel/intel-optimized-pytorch:2.4.0-pip-base AS cpu
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    git \
+    cmake \
+    ninja-build \
+    python3-dev &&\
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+
+RUN python -m pip install torch==2.4.0 torchvision torchaudio==2.4.0 --index-url https://download.pytorch.org/whl/cpu
+
+RUN cd backends/python/server && \
+    make install
+
+FROM vault.habana.ai/gaudi-docker/1.17.1/ubuntu22.04/habanalabs/pytorch-installer-2.3.1:latest AS hpu
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+
+RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
+    build-essential \
+    git \
+    cmake \
+    ninja-build \
+    python3-dev &&\
+    rm -rf /var/lib/apt/lists/*
+
+WORKDIR /usr/src
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-hpu.txt backends/python/server/requirements.txt
+
+RUN cd backends/python/server && \
+    make install
+
+FROM intel/intel-extension-for-pytorch:2.3.110-xpu AS xpu
+
+ENV HUGGINGFACE_HUB_CACHE=/data \
+    PORT=80
+RUN wget http://nz2.archive.ubuntu.com/ubuntu/pool/main/o/openssl/libssl1.1_1.1.1f-1ubuntu2_amd64.deb && \
+    dpkg -i ./libssl1.1_1.1.1f-1ubuntu2_amd64.deb
+
+RUN wget -qO - https://repositories.intel.com/gpu/intel-graphics.key | gpg --dearmor | tee /usr/share/keyrings/intel-graphics.gpg > /dev/null
+
+RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB \
+    | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list
+
+RUN apt-get update && apt install -y intel-basekit xpu-smi cmake python3-dev ninja-build pciutils
+WORKDIR /usr/src
+RUN pip install torch==2.3.1+cxx11.abi torchvision==0.18.1+cxx11.abi torchaudio==2.3.1+cxx11.abi intel-extension-for-pytorch==2.3.110+xpu oneccl_bind_pt==2.3.100+xpu --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ --no-cache-dir
+
+ENV CCL_ROOT=/opt/intel/oneapi/ccl/latest
+ENV I_MPI_ROOT=/opt/intel/oneapi/mpi/latest
+ENV FI_PROVIDER_PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib/prov:/usr/lib/x86_64-linux-gnu/libfabric
+ENV LIBRARY_PATH=/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mkl/latest/lib/:/opt/intel/oneapi/compiler/latest/lib
+ENV LD_LIBRARY_PATH=/opt/intel/oneapi/ccl/latest/lib/:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/lib:/opt/intel/oneapi/mpi/latest/lib:/opt/intel/oneapi/mkl/latest/lib:/opt/intel/oneapi/compiler/latest/opt/compiler/lib:/opt/intel/oneapi/compiler/latest/lib:/opt/intel/oneapi/lib:/opt/intel/oneapi/lib/intel64:
+ENV PATH=/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mpi/latest/bin:/opt/intel/oneapi/mpi/latest/opt/mpi/libfabric/bin:/opt/intel/oneapi/mkl/latest/bin/:/opt/intel/oneapi/compiler/latest/bin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin
+ENV CCL_ZE_IPC_EXCHANGE=sockets
+ENV CMAKE_PREFIX_PATH=/opt/intel/oneapi/mkl/latest/lib/cmake:/opt/intel/oneapi/compiler/latest
+ENV CPATH=/opt/intel/oneapi/mpi/latest/include:/opt/intel/oneapi/ccl/latest/include:/opt/intel/oneapi/mkl/latest/include
+
+COPY backends backends
+COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
+COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
+COPY backends/python/server/requirements-intel.txt backends/python/server/requirements.txt
+RUN cd backends/python/server && \
+    make install
+
+FROM ${PLATFORM} AS grpc
+
+COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
+
+FROM ${PLATFORM}
+
+COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router
+
+ENTRYPOINT ["text-embeddings-router"]
+CMD ["--json-output"]
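
The multi-stage Dockerfile above builds the router once and selects the runtime image through the PLATFORM build argument (cpu, hpu, or xpu); a separate grpc target produces the gRPC variant. A usage sketch, assuming a build from the repository root (the -t image tags are arbitrary examples, not defined by this commit):

# HTTP image on the Intel XPU runtime stage
docker build -f Dockerfile-intel --build-arg PLATFORM=xpu -t tei-python-xpu .

# gRPC image on the Gaudi/HPU runtime stage, selected via the named target
docker build -f Dockerfile-intel --build-arg PLATFORM=hpu --target grpc -t tei-python-hpu-grpc .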

backends/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -15,6 +15,7 @@ text-embeddings-backend-candle = { path = "candle", optional = true }
 text-embeddings-backend-ort = { path = "ort", optional = true }
 tokio = { workspace = true }
 tracing = { workspace = true }
+rand = { workspace = true }
 
 [features]
 clap = ["dep:clap", "text-embeddings-backend-core/clap"]

backends/python/server/Makefile

Lines changed: 1 addition & 2 deletions
@@ -15,8 +15,7 @@ gen-server:
 
 install: gen-server
 	pip install pip --upgrade
-	pip install torch==2.5.1
-	pip install -r requirements.txt
+	pip install --no-deps -r requirements.txt
 	pip install -e .
 
 run-dev:

backends/python/server/pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -21,7 +21,6 @@ opentelemetry-api = "^1.25.0"
 opentelemetry-exporter-otlp = "^1.25.0"
 opentelemetry-instrumentation-grpc = "^0.46b0"
 sentence-transformers = "^3.3.1"
-torch = "^2.5.1"
 
 [tool.poetry.extras]

backends/python/server/requirements-hpu.txt

Lines changed: 61 additions & 0 deletions

@@ -0,0 +1,61 @@
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.13.4 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+fsspec[http]==2024.2.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.24.5 ; python_version >= "3.9" and python_version < "3.13"
+humanfriendly==10.0 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+importlib-metadata==7.1.0 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.3 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.2.1 ; python_version >= "3.9" and python_version < "3.13"
+numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+optimum-habana==1.13.1 ; python_version >= "3.9" and python_version < "3.13"
+optimum==1.21.4 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+pandas==2.2.2 ; python_version >= "3.9" and python_version < "3.13"
+pillow==10.3.0 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+regex==2024.4.16 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
+six==1.16.0 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+tokenizers==0.19.1 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.2 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
+transformers[sentencepiece]==4.43.4 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.8.0 ; python_version >= "3.9" and python_version < "3.13"
+tzdata==2024.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+xxhash==3.4.1 ; python_version >= "3.9" and python_version < "3.13"
+yarl==1.9.4 ; python_version >= "3.9" and python_version < "3.13"
+zipp==3.18.1 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"
+einops==0.8.0 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/requirements-intel.txt

Lines changed: 44 additions & 0 deletions

@@ -0,0 +1,44 @@
+backoff==2.2.1 ; python_version >= "3.9" and python_version < "3.13"
+certifi==2023.7.22 ; python_version >= "3.9" and python_version < "3.13"
+charset-normalizer==3.2.0 ; python_version >= "3.9" and python_version < "3.13"
+click==8.1.7 ; python_version >= "3.9" and python_version < "3.13"
+colorama==0.4.6 ; python_version >= "3.9" and python_version < "3.13" and (sys_platform == "win32" or platform_system == "Windows")
+deprecated==1.2.14 ; python_version >= "3.9" and python_version < "3.13"
+filelock==3.12.3 ; python_version >= "3.9" and python_version < "3.13"
+fsspec==2023.9.0 ; python_version >= "3.9" and python_version < "3.13"
+googleapis-common-protos==1.60.0 ; python_version >= "3.9" and python_version < "3.13"
+grpc-interceptor==0.15.3 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-reflection==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio-status==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+grpcio==1.58.0 ; python_version >= "3.9" and python_version < "3.13"
+huggingface-hub==0.19.3 ; python_version >= "3.9" and python_version < "3.13"
+idna==3.4 ; python_version >= "3.9" and python_version < "3.13"
+jinja2==3.1.2 ; python_version >= "3.9" and python_version < "3.13"
+loguru==0.6.0 ; python_version >= "3.9" and python_version < "3.13"
+markupsafe==2.1.3 ; python_version >= "3.9" and python_version < "3.13"
+mpmath==1.3.0 ; python_version >= "3.9" and python_version < "3.13"
+networkx==3.1 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-api==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-grpc==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp-proto-http==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-exporter-otlp==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation-grpc==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-instrumentation==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-proto==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-sdk==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+opentelemetry-semantic-conventions==0.36b0 ; python_version >= "3.9" and python_version < "3.13"
+packaging==23.1 ; python_version >= "3.9" and python_version < "3.13"
+protobuf==4.24.3 ; python_version >= "3.9" and python_version < "3.13"
+pyyaml==6.0.1 ; python_version >= "3.9" and python_version < "3.13"
+requests==2.31.0 ; python_version >= "3.9" and python_version < "3.13"
+safetensors==0.4.1 ; python_version >= "3.9" and python_version < "3.13"
+setuptools==68.2.0 ; python_version >= "3.9" and python_version < "3.13"
+sympy==1.12 ; python_version >= "3.9" and python_version < "3.13"
+tqdm==4.66.1 ; python_version >= "3.9" and python_version < "3.13"
+typer==0.6.1 ; python_version >= "3.9" and python_version < "3.13"
+typing-extensions==4.7.1 ; python_version >= "3.9" and python_version < "3.13"
+urllib3==2.0.4 ; python_version >= "3.9" and python_version < "3.13"
+win32-setctime==1.1.0 ; python_version >= "3.9" and python_version < "3.13" and sys_platform == "win32"
+wrapt==1.15.0 ; python_version >= "3.9" and python_version < "3.13"
+transformers==4.40.0 ; python_version >= "3.9" and python_version < "3.13"
+pyrsistent==0.20.0 ; python_version >= "3.9" and python_version < "3.13"

backends/python/server/text_embeddings_server/cli.py

Lines changed: 0 additions & 1 deletion
@@ -48,7 +48,6 @@ def serve(
 
     # Downgrade enum into str for easier management later on
     dtype = None if dtype is None else dtype.value
-
     server.serve(model_path, dtype, uds_path, pool)
 
 
backends/python/server/text_embeddings_server/models/__init__.py

Lines changed: 23 additions & 13 deletions
@@ -8,6 +8,7 @@
 
 from text_embeddings_server.models.model import Model
 from text_embeddings_server.models.default_model import DefaultModel
+from text_embeddings_server.utils.device import get_device, use_ipex
 
 __all__ = ["Model"]
 
@@ -27,33 +28,42 @@
 
 def get_model(model_path: Path, dtype: Optional[str], pool: str):
     if dtype == "float32":
-        dtype = torch.float32
+        datatype = torch.float32
     elif dtype == "float16":
-        dtype = torch.float16
+        datatype = torch.float16
     elif dtype == "bfloat16":
-        dtype = torch.bfloat16
+        datatype = torch.bfloat16
     else:
         raise RuntimeError(f"Unknown dtype {dtype}")
 
-    if torch.cuda.is_available():
-        device = torch.device("cuda")
-    else:
-        device = torch.device("cpu")
+    device = get_device()
+    logger.info(f"backend device: {device}")
 
     config = AutoConfig.from_pretrained(model_path)
-
     if config.model_type == "bert":
         config: BertConfig
         if (
             device.type == "cuda"
             and config.position_embedding_type == "absolute"
-            and dtype in [torch.float16, torch.bfloat16]
+            and datatype in [torch.float16, torch.bfloat16]
             and FLASH_ATTENTION
         ):
             if pool != "cls":
                 raise ValueError("FlashBert only supports cls pooling")
-            return FlashBert(model_path, device, dtype)
-        else:
-            return DefaultModel(model_path, device, dtype, pool)
+            return FlashBert(model_path, device, datatype)  # type: ignore
+        if use_ipex() or device.type == "hpu":
+            return FlashBert(model_path, device, datatype)  # type: ignore
+
+        return DefaultModel(model_path, device, datatype)
+    else:
+        if device.type == "hpu":
+            from habana_frameworks.torch.hpu import wrap_in_hpu_graph
+            from optimum.habana.transformers.modeling_utils import (
+                adapt_transformers_to_gaudi,
+            )
 
-    raise NotImplementedError
+            adapt_transformers_to_gaudi()
+            model_handle = DefaultModel(model_path, device, datatype)
+            model_handle.model = wrap_in_hpu_graph(model_handle.model)
+            return model_handle
+        return DefaultModel(model_path, device, datatype)
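
get_device() and use_ipex() are imported from the new text_embeddings_server/utils/device.py module, which is among the 17 changed files but is not rendered in this excerpt. Below is a minimal sketch of the detection logic those imports suggest; only the two imported names come from the diff, while the _hpu_available() helper and the probe order are assumptions, and the actual module may differ.

import importlib.util

import torch


def use_ipex() -> bool:
    # Assumption: IPEX is considered usable when intel_extension_for_pytorch
    # is installed, as it is in the cpu and xpu images built above.
    return importlib.util.find_spec("intel_extension_for_pytorch") is not None


def _hpu_available() -> bool:
    # Hypothetical helper: Gaudi devices are only reachable when
    # habana_frameworks is importable (the hpu image ships it).
    if importlib.util.find_spec("habana_frameworks") is None:
        return False
    import habana_frameworks.torch.hpu as hpu

    return hpu.is_available()


def get_device() -> torch.device:
    # Assumed probe order: CUDA, then HPU, then XPU, falling back to CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if _hpu_available():
        return torch.device("hpu")
    if use_ipex():
        import intel_extension_for_pytorch  # noqa: F401 -- registers torch.xpu

        if hasattr(torch, "xpu") and torch.xpu.is_available():
            return torch.device("xpu")
    return torch.device("cpu")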
