Skip to content

Commit ca7cf03

Browse files
Support infer and deploy of embedding models (#4927)
1 parent fcc1c2d commit ca7cf03

File tree

16 files changed

+335
-78
lines changed

16 files changed

+335
-78
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@ You can contact us and communicate with us by adding our group:
7575

7676

7777
## 🎉 News
78+
- 🎁 2025.07.12: Deployment (pt/vLLM/SGLang) of Embedding models is supported; see the example [here](examples/deploy/embedding/client.py).
7879
- 🎁 2025.07.09: Megatron-SWIFT supports LoRA training. Compared to ms-swift, it achieves significant speedup on MoE models. Training scripts can be found [here](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/lora).
7980
- 🎁 2025.06.23: Fine-tuning of reranker models is supported. Training scripts can be found here: [Reranker](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh).
8081
- 🎁 2025.06.18: Support for accelerating the ms-swift [inference](https://github.com/modelscope/ms-swift/blob/main/examples/infer/sglang), deployment, evaluation, and UI modules using the [sglang](https://github.com/sgl-project/sglang) inference acceleration engine. Simply set `--infer_backend sglang` to enable it.

README_CN.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@
7171
- **模型量化**:支持AWQ、GPTQ、FP8和BNB的量化导出,导出的模型支持使用vLLM/SGLang/LmDeploy推理加速,并支持继续训练。
7272

7373
## 🎉 新闻
74+
- 🎁 2025.07.12: 支持Embedding模型的部署(pt/vLLM/SGLang),查看[这里](examples/deploy/embedding/client.py)。
7475
- 🎁 2025.07.09: Megatron-SWIFT支持LoRA训练。相比ms-swift,在MoE模型提速显著。训练脚本参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/lora)
7576
- 🎁 2025.06.23: 支持Reranker模型训练,训练脚本参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/reranker/train_reranker.sh)
7677
- 🎁 2025.06.18: 支持使用[sglang](https://github.com/sgl-project/sglang)推理加速引擎对ms-swift[推理](https://github.com/modelscope/ms-swift/blob/main/examples/infer/sglang)/部署/评测/ui模块进行加速,设置`--infer_backend sglang`即可。

docs/source/BestPractices/Embedding训练.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ SWIFT提供了两个脚手架训练脚本:
107107

108108
## 推理
109109

110-
SWIFT当前没有支持Embedding的模型推理和部署(时间问题),可以使用原模型的代码进行推理:
110+
SWIFT已经支持GME、GTE、Qwen3-Embedding模型的部署,请查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/embedding/client.py).
111+
112+
也可以使用原模型的代码进行推理:
111113

112114
https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct
113115

docs/source_en/BestPractices/Embedding.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,9 @@ SWIFT provides two scaffold training scripts:
107107

108108
## Inference
109109

110-
SWIFT currently does not support Embedding model inference and deployment (due to time constraints). You can use the original model's code for inference:
110+
SWIFT now supports the deployment of GME, GTE, and Qwen3-Embedding models; please check [here](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/embedding/client.py).
111+
112+
You can also use the original model's code for inference:
111113

112114
https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct
113115

examples/deploy/embedding/client.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import os
3+
4+
from openai import OpenAI
5+
6+
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
7+
8+
9+
def infer(client, model: str, messages):
    """Request an embedding for ``messages`` and print a short preview of it.

    The request is sent through the chat-completions route. ``client.embeddings.create``
    can also be used, but that interface does not support multi-modal medias.

    Args:
        client: OpenAI-compatible client pointing at the deployed server.
        model: Id of the model to query.
        messages: Chat messages in OpenAI format; the content of the first
            message is echoed as the query.

    Returns:
        The embedding taken from the first item of ``resp.data``.
    """
    resp = client.chat.completions.create(model=model, messages=messages)
    emb = resp.data[0]['embedding']
    shape = len(emb)
    sample = str(emb)
    if shape > 6:
        # Abbreviate long vectors to their first and last three values.
        sample = str(emb[:3])[:-1] + ', ..., ' + str(emb[-3:])[1:]
    # Echo the request content; formatting the builtin `input` here would
    # print "<built-in function input>" instead of the query.
    query = messages[0]['content'] if messages else messages
    print(f'query: {query}')
    print(f'Embedding(shape: [1, {shape}]): {sample}')
    return emb
21+
22+
23+
def run_client(host: str = '127.0.0.1', port: int = 8000):
    """Connect to the deployed server and embed one sample text query.

    The first model listed by the server is used for the request.

    Args:
        host: Host name or address of the server.
        port: Port the server listens on.
    """
    base_url = f'http://{host}:{port}/v1'
    client = OpenAI(api_key='EMPTY', base_url=base_url)
    served_models = client.models.list().data
    model = served_models[0].id
    print(f'model: {model}')

    # To send an image as well, add an item like the following to `content`:
    # {'type': 'image', 'image': 'http://modelscope-open.oss-cn-hangzhou.aliyuncs.com/images/animal.png'}
    text_item = {'type': 'text', 'text': 'What is the capital of China?'}
    messages = [{'role': 'user', 'content': [text_item]}]
    infer(client, model, messages)
46+
47+
48+
if __name__ == '__main__':
    # Start a deployment of the embedding model, then query it with the client
    # on the port yielded by `run_deploy`.
    from swift.llm import DeployArguments, run_deploy

    deploy_args = DeployArguments(
        model='Qwen/Qwen3-Embedding-0.6B',
        task_type='embedding',
        infer_backend='vllm',
        verbose=False,
        log_interval=-1,
    )
    with run_deploy(deploy_args) as port:
        run_client(port=port)

examples/deploy/embedding/server.sh

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
CUDA_VISIBLE_DEVICES=0 swift deploy \
2+
--host 0.0.0.0 \
3+
--port 8000 \
4+
--model Qwen/Qwen3-Embedding-0.6B \
5+
--infer_backend sglang

swift/llm/argument/infer_args.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ def get_vllm_engine_kwargs(self):
104104
'use_async_engine': self.use_async_engine,
105105
'quantization': self.vllm_quantization,
106106
}
107+
if self.task_type == 'embedding':
108+
kwargs['task_type'] = 'embed'
107109
return kwargs
108110

109111

@@ -135,6 +137,8 @@ def get_sglang_engine_kwargs(self):
135137
'enable_dp_attention': self.sglang_enable_dp_attention,
136138
'disable_custom_all_reduce': self.sglang_disable_custom_all_reduce,
137139
}
140+
if self.task_type == 'embedding':
141+
kwargs['task_type'] = 'embedding'
138142
return kwargs
139143

140144

swift/llm/infer/deploy.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from fastapi.responses import JSONResponse, StreamingResponse
1717

1818
from swift.llm import AdapterRequest, DeployArguments
19-
from swift.llm.infer.protocol import MultiModalRequestMixin
19+
from swift.llm.infer.protocol import EmbeddingRequest, MultiModalRequestMixin
2020
from swift.plugin import InferStats
2121
from swift.utils import JsonlWriter, get_logger
2222
from .infer import SwiftInfer
@@ -34,6 +34,7 @@ def _register_app(self):
3434
self.app.get('/v1/models')(self.get_available_models)
3535
self.app.post('/v1/chat/completions')(self.create_chat_completion)
3636
self.app.post('/v1/completions')(self.create_completion)
37+
self.app.post('/v1/embeddings')(self.create_embedding)
3738

3839
def __init__(self, args: Union[List[str], DeployArguments, None] = None) -> None:
3940
super().__init__(args)
@@ -183,13 +184,20 @@ async def _gen_wrapper():
183184
yield 'data: [DONE]\n\n'
184185

185186
return StreamingResponse(_gen_wrapper(), media_type='text/event-stream')
186-
else:
187+
elif hasattr(res_or_gen, 'choices'):
188+
# instance of ChatCompletionResponse
187189
return self._post_process(request_info, res_or_gen, return_cmpl_response)
190+
else:
191+
return res_or_gen
188192

189193
async def create_completion(self, request: CompletionRequest, raw_request: Request):
190194
chat_request = ChatCompletionRequest.from_cmpl_request(request)
191195
return await self.create_chat_completion(chat_request, raw_request, return_cmpl_response=True)
192196

197+
async def create_embedding(self, request: EmbeddingRequest, raw_request: Request):
    """Handle an embedding request by delegating to the chat-completion handler.

    The incoming request is converted with ``ChatCompletionRequest.from_cmpl_request``
    and forwarded to ``create_chat_completion``; whatever that handler returns
    is returned unchanged.
    """
    converted_request = ChatCompletionRequest.from_cmpl_request(request)
    result = await self.create_chat_completion(converted_request, raw_request, return_cmpl_response=True)
    return result
200+
193201
def run(self):
194202
args = self.args
195203
self.jsonl_writer = JsonlWriter(args.result_path) if args.result_path else None

swift/llm/infer/infer.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,18 @@ def run(self) -> List[Dict[str, Any]]:
9393
logger.info(f'The inference results have been saved to result_path: `{args.result_path}`.')
9494
return result
9595

96+
@staticmethod
def parse_data_from_response(response):
    """Turn an inference response into the text shown to the user.

    A response with ``choices`` yields the message content of its first choice.
    A response with ``data`` yields a one-line summary of the first embedding:
    its length plus a preview (first and last three values when longer than six).
    Any other response yields ``None``.
    """
    if hasattr(response, 'choices'):
        return response.choices[0].message.content
    if not hasattr(response, 'data'):
        return None

    vector = response.data[0].embedding
    size = len(vector)
    if size > 6:
        # Drop the closing bracket of the head and the opening bracket of the
        # tail so the two slices read as one abbreviated list.
        head, tail = str(vector[:3]), str(vector[-3:])
        preview = head[:-1] + ', ..., ' + tail[1:]
    else:
        preview = str(vector)
    return f'Embedding(shape: [1, {size}]): {preview}'
107+
96108
def infer_single(self, infer_request: Union[InferRequest, Dict[str, Any]], request_config: RequestConfig) -> str:
97109
res_or_gen = self.infer([infer_request],
98110
request_config,
@@ -107,7 +119,7 @@ def infer_single(self, infer_request: Union[InferRequest, Dict[str, Any]], reque
107119
response += delta
108120
print()
109121
else:
110-
response = res_or_gen.choices[0].message.content
122+
response = self.parse_data_from_response(res_or_gen)
111123
print(response)
112124
print('-' * 50)
113125
return response

swift/llm/infer/infer_engine/pt_engine.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,8 @@
1919
from swift.plugin import Metric
2020
from swift.tuners import Swift
2121
from ..protocol import (ChatCompletionResponse, ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice,
22-
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, RequestConfig, random_uuid)
22+
ChatCompletionStreamResponse, ChatMessage, DeltaMessage, EmbeddingResponse,
23+
EmbeddingResponseData, RequestConfig, random_uuid)
2324
from .infer_engine import InferEngine
2425
from .utils import AdapterRequest, InferStreamer, LogitsStreamer, TokensIteratorStreamer, prepare_generation_config
2526

@@ -325,26 +326,42 @@ def _infer_forward(self,
325326
call_kwargs['adapter_names'] = adapter_names
326327
num_prompt_tokens = self._get_num_tokens(inputs)
327328
inputs.pop('labels', None)
328-
logits = self.model(**inputs, **call_kwargs).logits
329+
output = self.model(**inputs, **call_kwargs)
330+
if hasattr(output, 'logits'):
331+
logits = output.logits
332+
elif 'last_hidden_state' in output:
333+
# embeddings
334+
logits = output['last_hidden_state']
329335
if template.mode == 'seq_cls':
330336
preds, logprobs = template.decode_seq_cls(logits, top_logprobs)
331337
elif template.mode == 'prm':
332338
preds = template.decode_prm(inputs['input_ids'], logits)
333339
logprobs = [None] * len(preds)
340+
elif template.mode == 'embedding':
341+
preds = logits
342+
logprobs = [None] * len(preds)
334343
else:
335344
raise ValueError(f'Unsupported mode: {template.mode}')
336345

337346
res = []
338347
for i, pred in enumerate(preds):
339348
usage_info = self._get_usage_info(num_prompt_tokens, 1)
340-
choices = [
341-
ChatCompletionResponseChoice(
342-
index=0,
343-
message=ChatMessage(role='assistant', content=pred, tool_calls=None),
344-
finish_reason='stop',
345-
logprobs=logprobs[i])
346-
]
347-
res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
349+
if template.mode != 'embedding':
350+
choices = [
351+
ChatCompletionResponseChoice(
352+
index=0,
353+
message=ChatMessage(role='assistant', content=pred, tool_calls=None),
354+
finish_reason='stop',
355+
logprobs=logprobs[i])
356+
]
357+
res.append(ChatCompletionResponse(model=self.model_name, choices=choices, usage=usage_info))
358+
else:
359+
res.append(
360+
EmbeddingResponse(
361+
model=self.model_name,
362+
usage=usage_info,
363+
data=[EmbeddingResponseData(embedding=pred.to(torch.float32).cpu().numpy().tolist())]))
364+
348365
return res
349366

350367
def _infer_full(self,
@@ -502,7 +519,8 @@ def _gen_wrapper():
502519
return _gen_wrapper()
503520
else:
504521
if len(kwargs) > 0:
505-
infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm') else self._infer_full
522+
infer_func = self._infer_forward if template.mode in ('seq_cls', 'prm',
523+
'embedding') else self._infer_full
506524
res = infer_func(**kwargs)
507525
else:
508526
res = []

0 commit comments

Comments
 (0)