Skip to content
Open
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
710b8c1
1st draft
JingyaHuang Aug 22, 2025
cc84f29
Merge branch 'main' into add-neuron-backend
JingyaHuang Oct 7, 2025
139b179
feat: sentence transformer for neuron
JingyaHuang Oct 22, 2025
dd0c08d
fix: neuron dockerfile
JingyaHuang Oct 27, 2025
1e4f3c9
remove useless
JingyaHuang Oct 28, 2025
adfa2e9
Merge branch 'main' into add-neuron-backend
JingyaHuang Oct 28, 2025
a25cf98
fix dockerfile
JingyaHuang Oct 31, 2025
56c15d8
neuron path
JingyaHuang Nov 3, 2025
142520a
fix container env + Neuron related changes
JingyaHuang Nov 3, 2025
7ada877
fix for neuron backend + tests
JingyaHuang Feb 3, 2026
976b71c
add to CI & add pre-compiled test
JingyaHuang Feb 4, 2026
dc3edc2
fix tests
JingyaHuang Feb 4, 2026
3676b94
Merge branch 'main' into add-neuron-backend
JingyaHuang Feb 4, 2026
b803566
snol fix
JingyaHuang Feb 5, 2026
81c57d3
fix doc index
JingyaHuang Feb 5, 2026
7f517b9
fix style
JingyaHuang Feb 5, 2026
9752998
build and push neuron docker images in CI
JingyaHuang Feb 5, 2026
c517aa2
smol changes
JingyaHuang Feb 5, 2026
d1708a3
Merge branch 'huggingface:main' into add-neuron-backend
JingyaHuang Feb 10, 2026
08301f0
Merge branch 'main' into add-neuron-backend
Feb 17, 2026
37519d9
Merge branch 'main' into add-neuron-backend
JingyaHuang Feb 20, 2026
533d853
Update Dockerfile-neuron
JingyaHuang Feb 20, 2026
0829b6f
Apply suggestions from code review
JingyaHuang Feb 23, 2026
aa47549
Merge branch 'main' into add-neuron-backend
JingyaHuang Feb 23, 2026
1464cc3
review:suggestions
JingyaHuang Feb 23, 2026
9961846
Merge branch 'add-neuron-backend' of github.com:JingyaHuang/text-embe…
JingyaHuang Feb 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/build.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ on:
- "Cargo.lock"
- "rust-toolchain.toml"
- "Dockerfile"
- "Dockerfile-neuron"
branches:
- "main"

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
name: Run integration tests
name: Run Habana integration tests

on:
workflow_dispatch:
Expand Down Expand Up @@ -28,4 +28,4 @@ jobs:
working-directory: integration_tests
run: |
uv sync --locked --all-extras --dev
uv run pytest --durations=0 -sv .
uv run pytest --durations=0 -sv gaudi/
33 changes: 33 additions & 0 deletions .github/workflows/integration-test-neuron.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Nightly / on-demand integration tests for the AWS Neuron (Inferentia2) backend.
# Builds the Neuron Docker image from source and runs the neuron/ test suite
# against it on self-hosted Inferentia2 hardware.
name: Run Neuron integration tests

on:
  workflow_dispatch:
  schedule:
    - cron: '0 0 * * *' # Run the workflow nightly to check Neuron integration is working

jobs:
  tests:
    concurrency:
      # One active run per ref; a newer trigger cancels any in-flight run.
      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
      cancel-in-progress: true
    runs-on:
      # Self-hosted runner group backed by AWS inf2.8xlarge instances.
      group: aws-inf2-8xlarge
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5

      - name: Build Docker image for Neuron
        run: |
          docker build . -f Dockerfile-neuron -t tei-neuron

      - name: Run integration tests
        working-directory: integration_tests
        env:
          # HF_TOKEN is needed to pull gated/private models during tests.
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          # Image tag the test harness launches; must match the build step above.
          DOCKER_IMAGE: tei-neuron
        run: |
          uv sync --locked --all-extras --dev
          uv run pytest --durations=0 -sv neuron/
8 changes: 8 additions & 0 deletions .github/workflows/matrix.json
Original file line number Diff line number Diff line change
Expand Up @@ -105,5 +105,13 @@
"extraBuildArgs": "PLATFORM=hpu",
"grpc": true,
"dockerfile": "Dockerfile-intel"
},
{
"name": "neuron",
"imageNamePrefix": "neuron-",
"runOn": "always",
"sccache": true,
"grpc": true,
"dockerfile": "Dockerfile-neuron"
}
]
190 changes: 190 additions & 0 deletions Dockerfile-neuron
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
# Base stage shared by planner/builder: Rust toolchain + cargo-chef + sccache.
FROM lukemathwalker/cargo-chef:latest-rust-1.85-bookworm AS chef
WORKDIR /usr/src

# sccache release to install. ARG (not ENV) because the version is only needed
# at build time and should not leak into derived stages' environments; still
# overridable with --build-arg SCCACHE=<version>.
ARG SCCACHE=0.10.0
# RUSTC_WRAPPER must remain ENV: the cargo invocations in the stages derived
# from `chef` rely on it to route compilations through sccache.
ENV RUSTC_WRAPPER=/usr/local/bin/sccache

# Download and install the statically-linked sccache binary.
RUN curl -fsSL https://github.com/mozilla/sccache/releases/download/v$SCCACHE/sccache-v$SCCACHE-x86_64-unknown-linux-musl.tar.gz | tar -xzv --strip-components=1 -C /usr/local/bin sccache-v$SCCACHE-x86_64-unknown-linux-musl/sccache && \
    chmod +x /usr/local/bin/sccache

# Planner stage: compute the cargo-chef dependency recipe from the workspace
# manifests and sources. Only the recipe is carried forward to the builder.
FROM chef AS planner

# Workspace manifests and member crates needed by `cargo chef prepare`.
COPY Cargo.toml Cargo.lock ./
COPY core core
COPY router router
COPY backends backends

# Emit recipe.json describing the full dependency graph.
RUN cargo chef prepare --recipe-path recipe.json

# Builder stage: pre-build (cook) dependencies once, then lay down the sources
# that the http-builder / grpc-builder child stages compile.
FROM chef AS builder

# Install protoc first: it depends on nothing in the source tree, so this layer
# stays cached forever. Previously it ran after the source COPYs, forcing a
# re-download of protoc on every source change.
RUN PROTOC_ZIP=protoc-21.12-linux-x86_64.zip && \
    curl -OL https://github.com/protocolbuffers/protobuf/releases/download/v21.12/$PROTOC_ZIP && \
    unzip -o $PROTOC_ZIP -d /usr/local bin/protoc && \
    unzip -o $PROTOC_ZIP -d /usr/local 'include/*' && \
    rm -f $PROTOC_ZIP

# sccache GitHub Actions cache plumbing; read by sccache during the cook below.
ARG SCCACHE_GHA_ENABLED

COPY --from=planner /usr/src/recipe.json recipe.json

# Cook (pre-compile) all dependencies from the recipe; this expensive layer is
# cached until Cargo.toml/Cargo.lock change.
RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
    cargo chef cook --release --features python --no-default-features --recipe-path recipe.json && sccache -s

# Declared after the cook step on purpose: GIT_SHA changes on every commit and,
# if declared earlier, would invalidate the dependency-cook cache each time.
ARG GIT_SHA
ARG DOCKER_LABEL

COPY backends backends
COPY core core
COPY router router
COPY Cargo.toml ./
COPY Cargo.lock ./

# HTTP variant of the router binary (python backend, default features off).
FROM builder AS http-builder

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
    cargo build --release --bin text-embeddings-router --no-default-features --features python --features http && sccache -s

# gRPC variant of the router binary (python backend, default features off).
FROM builder AS grpc-builder

# Protobuf service definitions are only needed for the gRPC front-end.
COPY proto proto

RUN --mount=type=secret,id=actions_results_url,env=ACTIONS_RESULTS_URL \
    --mount=type=secret,id=actions_runtime_token,env=ACTIONS_RUNTIME_TOKEN \
    cargo build --release --bin text-embeddings-router --no-default-features --features python --features grpc && sccache -s

# Runtime stage: AWS Neuron (Inferentia/Trainium) environment. Ubuntu 22.04
# ("jammy") is required by the AWS Neuron apt repository configured below.
FROM public.ecr.aws/docker/library/ubuntu:22.04 AS neuron

# HUGGINGFACE_HUB_CACHE: model download cache; mount a volume at /data to persist.
# PORT: default HTTP port served by text-embeddings-router.
ENV HUGGINGFACE_HUB_CACHE=/data \
    PORT=80

ENV PATH="/usr/local/bin:/root/.local/bin:${PATH}"

# Toolchain for building the Python gRPC backend server and its native deps.
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    python3 \
    python3-pip \
    python3-dev \
    build-essential \
    git \
    curl \
    cmake \
    pkg-config \
    protobuf-compiler \
    ninja-build \
    && rm -rf /var/lib/apt/lists/*

# Expose `python`/`pip` names; `|| true` tolerates pre-existing links on rebuilds.
# NOTE(review): `ln -sf` would be the more direct idiom and would not mask
# other failure modes of `ln`.
RUN ln -s /usr/bin/python3 /usr/local/bin/python || true
RUN ln -s /usr/bin/pip3 /usr/local/bin/pip || true

WORKDIR /usr/src
# NOTE(review): `COPY backends backends` already contains the two files copied
# individually below; the extra COPY lines are redundant (harmless, extra layers).
COPY backends backends
COPY backends/python/server/text_embeddings_server/models/__init__.py backends/python/server/text_embeddings_server/models/__init__.py
COPY backends/python/server/pyproject.toml backends/python/server/pyproject.toml
# Install the Python backend server (the gRPC bridge the Rust router talks to).
RUN cd backends/python/server && \
    make install

# Pinned AWS Neuron system packages (installed via apt below).
ARG NEURONX_COLLECTIVES_LIB_VERSION=2.28.27.0-bc30ece58
ARG NEURONX_RUNTIME_LIB_VERSION=2.28.23.0-dd5879008
ARG NEURONX_TOOLS_VERSION=2.26.14.0

# Pinned AWS Neuron Python packages (installed via pip below).
ARG NEURONX_CC_VERSION=2.21.33363.0+82129205
ARG NEURONX_FRAMEWORK_VERSION=2.8.0.2.10.16998+e9bf8a50
ARG NEURONX_DISTRIBUTED_VERSION=0.15.22404+1f27bddf

# NOTE(review): `apt-get upgrade -y` (hadolint DL3005) is usually replaced by
# bumping the base image tag/digest; editors and debug tools (emacs, vim, jq)
# bloat the runtime image, and build-essential/cmake/curl/git are already
# installed above. Consider trimming — verify nothing in the Neuron toolchain
# depends on them first.
RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install -y --no-install-recommends \
    apt-transport-https \
    build-essential \
    ca-certificates \
    cmake \
    curl \
    emacs \
    git \
    gnupg2 \
    gpg-agent \
    jq \
    libgl1-mesa-glx \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender-dev \
    libcap-dev \
    libhwloc-dev \
    openjdk-11-jdk \
    unzip \
    vim \
    wget \
    zlib1g-dev \
    && rm -rf /var/lib/apt/lists/* \
    && rm -rf /tmp/tmp* \
    && apt-get clean

# Ubuntu 22.04 = jammy; use signed-by (apt-key is deprecated)
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | gpg --dearmor -o /usr/share/keyrings/neuron-archive-keyring.gpg && \
    echo "deb [signed-by=/usr/share/keyrings/neuron-archive-keyring.gpg] https://apt.repos.neuron.amazonaws.com jammy main" > /etc/apt/sources.list.d/neuron.list

# Neuron system tooling + runtime libraries, pinned to the ARG versions above.
RUN apt-get update \
    && apt-get install -y \
    aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
    aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
    aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
    && rm -rf /var/lib/apt/lists/* \
    && rm -rf /tmp/tmp* \
    && apt-get clean

# Expose neuron-ls, neuron-top, etc.
ENV PATH="/opt/aws/neuron/bin:${PATH}"

# Neuron Python stack from the AWS pip repository (PyPI as fallback index).
# NOTE(review): torchvision is the only unpinned package here — consider
# pinning it for reproducibility.
RUN pip install --index-url https://pip.repos.neuron.amazonaws.com \
    --extra-index-url https://pypi.org/simple \
    --trusted-host pip.repos.neuron.amazonaws.com \
    neuronx-cc==$NEURONX_CC_VERSION \
    torch-neuronx==$NEURONX_FRAMEWORK_VERSION \
    torchvision \
    neuronx_distributed==$NEURONX_DISTRIBUTED_VERSION \
    && rm -rf ~/.cache/pip/*

# HF ARGS
# Note: optimum-neuron 0.4.4 requires transformers~=4.57.1
ARG TRANSFORMERS_VERSION=4.57.1
ARG DIFFUSERS_VERSION=0.35.2
ARG HUGGINGFACE_HUB_VERSION=0.36.0
ARG OPTIMUM_NEURON_VERSION=0.4.4
ARG SENTENCE_TRANSFORMERS=5.1.2
ARG PEFT_VERSION=0.17.0
ARG DATASETS_VERSION=4.1.1

# Install Hugging Face libraries and dependencies for TEI on Neuron
# NOTE(review): compel and controlnet-aux are unpinned; a reviewer on this PR
# suggested moving this whole set to a lock file instead of inline pins.
RUN pip install --no-cache-dir -U \
    networkx==2.8.8 \
    transformers[sentencepiece,audio,vision]==${TRANSFORMERS_VERSION} \
    diffusers==${DIFFUSERS_VERSION} \
    compel \
    controlnet-aux \
    huggingface_hub==${HUGGINGFACE_HUB_VERSION} \
    hf_transfer \
    datasets==${DATASETS_VERSION} \
    optimum-neuron==${OPTIMUM_NEURON_VERSION} \
    sentence_transformers==${SENTENCE_TRANSFORMERS} \
    peft==${PEFT_VERSION} \
    && rm -rf ~/.cache/pip/*
Comment on lines +163 to +175
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not something to tackle in this PR maybe, but I'd rather rely on a lock file here instead of those pins, so it might be worth considering re-opening #587?

cc @regisss and @kaixuanliu as this was something mentioned in the past, but apparently it was failing on Intel HPUs (?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should work on HPU; I'm not sure why it failed at that time. So don't hesitate to go that way, and if you have a lock file you would like me to test on HPU, I'm happy to do it :)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @regisss, I'll restart Nico's PR to add uv support instead, and ping you when done for testing 🤗



# Final gRPC image: Neuron runtime plus the gRPC-enabled router binary.
# Build with `--target grpc` to select this variant.
FROM neuron AS grpc

COPY --from=grpc-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

# Exec form: the router runs as PID 1 and receives SIGTERM directly on stop.
ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]

# Final (default) HTTP image: Neuron runtime plus the HTTP-enabled router
# binary. This unnamed last stage is what a plain `docker build` produces.
FROM neuron

COPY --from=http-builder /usr/src/target/release/text-embeddings-router /usr/local/bin/text-embeddings-router

# Exec form: the router runs as PID 1 and receives SIGTERM directly on stop.
ENTRYPOINT ["text-embeddings-router"]
CMD ["--json-output"]
60 changes: 51 additions & 9 deletions backends/python/server/text_embeddings_server/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,8 @@
from text_embeddings_server.models.masked_model import MaskedLanguageModel
from text_embeddings_server.models.default_model import DefaultModel
from text_embeddings_server.models.classification_model import ClassificationModel
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
from text_embeddings_server.utils.device import get_device, use_ipex

from text_embeddings_server.utils.device import get_device, use_ipex, is_neuron

__all__ = ["Model"]

Expand All @@ -23,19 +21,45 @@
"true",
"1",
]
# Disable gradients
torch.set_grad_enabled(False)

# Flash Attention models - only available when flash_attn is installed
FLASH_ATTENTION = True
FlashBert = None
FlashJinaBert = None
FlashMistral = None
FlashQwen3 = None

try:
from text_embeddings_server.models.flash_bert import FlashBert
from text_embeddings_server.models.jinaBert_model import FlashJinaBert
from text_embeddings_server.models.flash_mistral import FlashMistral
from text_embeddings_server.models.flash_qwen3 import FlashQwen3
# Disable gradients
torch.set_grad_enabled(False)
except ImportError as e:
logger.warning(f"Could not import Flash Attention enabled models: {e}")
FLASH_ATTENTION = False

if FLASH_ATTENTION:
__all__.append(FlashBert)

# Neuron models - only import when on Neuron device to avoid unnecessary dependencies
NeuronSentenceTransformersModel = None
NeuronClassificationModel = None
NeuronMaskedLMModel = None
create_neuron_model = None

if is_neuron():
try:
from text_embeddings_server.models.neuron_models import (
NeuronSentenceTransformersModel,
NeuronClassificationModel,
NeuronMaskedLMModel,
create_neuron_model,
)
except ImportError as e:
logger.warning(f"Could not import Neuron models: {e}")


def wrap_model_if_hpu(model_handle, device):
"""Wrap the model in HPU graph if the device is HPU."""
Expand Down Expand Up @@ -75,8 +99,26 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):

config = AutoConfig.from_pretrained(model_path, trust_remote_code=TRUST_REMOTE_CODE)

# Neuron cases - use optimum-neuron for all supported model types
if is_neuron():
logger.info(f"Neuron device detected, using optimum-neuron backend for model type: {config.model_type}")
try:
return create_neuron_model(
model_path=model_path,
device=device,
dtype=datatype,
pool=pool,
trust_remote=TRUST_REMOTE_CODE,
config=config,
)
except Exception as e:
logger.warning(f"Failed to load model with optimum-neuron: {e}")
logger.warning("Falling back to default model loading path")
# Fall through to default model loading

if (
hasattr(config, "auto_map")
FlashJinaBert is not None
and hasattr(config, "auto_map")
and isinstance(config.auto_map, dict)
and "AutoModel" in config.auto_map
and config.auto_map["AutoModel"]
Expand Down Expand Up @@ -116,13 +158,13 @@ def get_model(model_path: Path, dtype: Optional[str], pool: str):
else:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "mistral" and device.type == "hpu":
if config.model_type == "mistral" and device.type == "hpu" and FlashMistral is not None:
try:
return create_model(FlashMistral, model_path, device, datatype, pool)
except FileNotFoundError:
return create_model(DefaultModel, model_path, device, datatype, pool)

if config.model_type == "qwen3" and device.type == "hpu":
if config.model_type == "qwen3" and device.type == "hpu" and FlashQwen3 is not None:
try:
return create_model(FlashQwen3, model_path, device, datatype, pool)
except FileNotFoundError:
Expand Down
Loading
Loading