Commit c86945e
[Feature] support pool (#3827)

* support pool
* update pooling
* add pooler_config and check
* update
* support AutoWeightsLoader load weight
* fix
* update
* delete print
* update pre-commit
* fix
* fix xpu
* fix ModelRegistry->model_registry
* fix Copilot review
* fix pooler.py
* delete StepPooler
* fix abstract
* fix default_loader_v1
* fix Pre Commit
* support torch qwen3 dense
* add test and fix torch-qwen
* fix
* fix
* adapter ci
* fix review
* fix pooling_params.py
* fix
* fix tasks.py 2025
* fix print and logger
* Modify ModelRegistry and delete AutoWeightsLoader
* fix logger
* fix test_embedding
* fix ci bug
* ernie4_5 model_registry
* fix test
* support Qwen3-Embedding-0.6B tp=1 load
* fix extra code
* fix
* delete fix vocab_size
* delete prepare_params_dict
* fix:

1 parent da74a5f commit c86945e

36 files changed: 2,371 additions, 51 deletions

docs/features/plugins.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -18,7 +18,7 @@ Assuming you have a custom model class `MyModelForCasualLM` and a pretrained class
 
 ```python
 # File: fd_add_dummy_model/__init__.py or fd_add_dummy_model/register.py
-from fastdeploy.model_registry import ModelRegistry
+from fastdeploy.model_executor.models.model_base import ModelRegistry
 from my_custom_model import MyModelForCasualLM, MyPretrainedModel
 from fastdeploy.config import ErnieArchitectures
 
````
docs/zh/features/plugins.md

Lines changed: 1 addition & 1 deletion
````diff
@@ -18,7 +18,7 @@ FastDeploy uses Python's `entry_points` mechanism to discover and load plugins.
 
 ```python
 # File: fd_add_dummy_model/__init__.py
-from fastdeploy.model_registry import ModelRegistry
+from fastdeploy.model_executor.models.model_base import ModelRegistry
 from my_custom_model import MyModelForCasualLM, MyPretrainedModel
 
 def register():
````
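
Both docs patch the same plugin hook; only the `ModelRegistry` import path changes in this commit. For orientation, a hedged sketch of a complete `register()` hook under the new path; the `register_model_class` and `register_pretrained_model` calls are assumptions about the plugin-facing API and are not confirmed by this diff:

```python
# File: fd_add_dummy_model/__init__.py -- hypothetical sketch, not from this diff.
# Only the ModelRegistry import path is confirmed by this commit; the
# registration calls below are assumptions about the plugin API.
from fastdeploy.model_executor.models.model_base import ModelRegistry

from my_custom_model import MyModelForCasualLM, MyPretrainedModel


def register():
    # Guard against double registration if the plugin is loaded twice.
    if "MyModelForCasualLM" not in ModelRegistry.get_supported_archs():
        ModelRegistry.register_model_class(MyModelForCasualLM)
        ModelRegistry.register_pretrained_model(MyPretrainedModel)
```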

fastdeploy/config.py

Lines changed: 298 additions & 2 deletions
```diff
@@ -18,24 +18,83 @@
 
 import json
 import os
+from dataclasses import field
 from enum import Enum
 from typing import Any, Dict, List, Literal, Optional, Union
 
 import paddle
 import paddle.distributed as dist
 from paddleformers.transformers.configuration_utils import PretrainedConfig
+from typing_extensions import assert_never
 
 import fastdeploy
 from fastdeploy import envs
 from fastdeploy.model_executor.layers.quantization.quant_base import QuantConfigBase
 from fastdeploy.multimodal.registry import MultimodalRegistry
 from fastdeploy.platforms import current_platform
 from fastdeploy.scheduler import SchedulerConfig
+from fastdeploy.transformer_utils.config import get_pooling_config
 from fastdeploy.utils import ceil_div, check_unified_ckpt, get_host_ip, get_logger
 
 logger = get_logger("config", "config.log")
 
-TaskOption = Literal["generate"]
+TaskOption = Literal["auto", "generate", "embedding", "embed"]
+
+RunnerType = Literal["generate", "pooling"]
+
+RunnerOption = Literal["auto", "generate", "pooling"]
+
+ConvertOption = Literal["auto", "none", "embed"]
+
+ConvertType = Literal["none", "embed"]
+
+_ResolvedTask = Literal["generate", "encode", "embed"]
+
+_RUNNER_CONVERTS: dict[RunnerType, list[ConvertType]] = {
+    "generate": [],
+    "pooling": ["embed"],
+}
+
+# Some model suffixes are based on auto classes from Transformers:
+# https://huggingface.co/docs/transformers/en/model_doc/auto
+# NOTE: Items higher on this list take priority over lower ones
+_SUFFIX_TO_DEFAULTS: list[tuple[str, tuple[RunnerType, ConvertType]]] = [
+    ("ForCausalLM", ("generate", "none")),
+    ("ForConditionalGeneration", ("generate", "none")),
+    ("ChatModel", ("generate", "none")),
+    ("LMHeadModel", ("generate", "none")),
+    ("ForTextEncoding", ("pooling", "embed")),
+    ("EmbeddingModel", ("pooling", "embed")),
+    ("ForSequenceClassification", ("pooling", "classify")),
+    ("ForAudioClassification", ("pooling", "classify")),
+    ("ForImageClassification", ("pooling", "classify")),
+    ("ForVideoClassification", ("pooling", "classify")),
+    ("ClassificationModel", ("pooling", "classify")),
+    ("ForRewardModeling", ("pooling", "reward")),
+    ("RewardModel", ("pooling", "reward")),
+    # Let other `*Model`s take priority
+    ("Model", ("pooling", "embed")),
+]
+
+
+def iter_architecture_defaults():
+    yield from _SUFFIX_TO_DEFAULTS
+
+
+def try_match_architecture_defaults(
+    architecture: str,
+    *,
+    runner_type: Optional[RunnerType] = None,
+    convert_type: Optional[ConvertType] = None,
+):
+    for suffix, (default_runner_type, default_convert_type) in iter_architecture_defaults():
+        if (
+            (runner_type is None or runner_type == default_runner_type)
+            and (convert_type is None or convert_type == default_convert_type)
+            and architecture.endswith(suffix)
+        ):
+            return suffix, (default_runner_type, default_convert_type)
+    return None
 
 
 class MoEPhase:
```
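
A quick way to read the suffix table above: `try_match_architecture_defaults` scans `_SUFFIX_TO_DEFAULTS` top-down and returns the first suffix match, optionally filtered by runner or convert type. A small sketch of the resulting behavior (assuming this commit is installed, so the helper is importable from `fastdeploy.config`):

```python
from fastdeploy.config import try_match_architecture_defaults

# "ForCausalLM" sits near the top of the table, so generation wins even
# though the name also ends with the catch-all "Model" suffix.
print(try_match_architecture_defaults("Qwen3ForCausalLM"))
# -> ('ForCausalLM', ('generate', 'none'))

# A bare *Model name falls through to the last, lowest-priority entry.
print(try_match_architecture_defaults("Qwen3Model"))
# -> ('Model', ('pooling', 'embed'))

# Filtering by runner_type skips entries whose defaults do not match.
print(try_match_architecture_defaults("Qwen3ForCausalLM", runner_type="pooling"))
# -> None
```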

```diff
@@ -133,6 +192,12 @@ def __init__(
         self.eos_tokens_lens: int = 2
         self.lm_head_fp32: bool = False
         self.model_format = "auto"
+        self.runner = "auto"
+        self.convert = "auto"
+        self.pooler_config: Optional["PoolerConfig"] = field(init=False)
+        self.override_pooler_config: Optional[Union[dict, "PoolerConfig"]] = None
+        self.revision = None
+
         self.partial_rotary_factor: float = 1.0
         self.num_nextn_predict_layers = 0
         for key, value in args.items():
```

```diff
@@ -161,6 +226,7 @@ def __init__(
         self.ori_vocab_size = args.get("ori_vocab_size", self.vocab_size)
 
         architectures = self.architectures[0]
+
         if MultimodalRegistry.contains_model(architectures):
             self.enable_mm = True
         else:
```

```diff
@@ -171,6 +237,43 @@ def __init__(
         self.override_name_from_config()
         self.read_from_env()
         self.read_model_config()
+        self.runner_type = self._get_runner_type(self.architectures, self.runner)
+        self.convert_type = self._get_convert_type(self.architectures, self.runner_type, self.convert)
+
+        registry = self.registry
+        is_generative_model = registry.is_text_generation_model(self.architectures, self)
+        is_pooling_model = registry.is_pooling_model(self.architectures, self)
+        is_multimodal_model = registry.is_multimodal_model(self.architectures, self)
+
+        if self.runner_type == "generate" and not is_generative_model:
+            if is_multimodal_model:
+                pass
+            else:
+                generate_converts = _RUNNER_CONVERTS["generate"]
+                if self.convert_type not in generate_converts:
+                    raise ValueError("This model does not support `--runner generate`.")
+        if self.runner_type == "pooling" and not is_pooling_model:
+            pooling_converts = _RUNNER_CONVERTS["pooling"]
+            if self.convert_type not in pooling_converts:
+                convert_option = "<" + "|".join(pooling_converts) + ">"
+                raise ValueError(
+                    "This model does not support `--runner pooling`. "
+                    f"You can pass `--convert {convert_option}` to adapt "
+                    "it into a pooling model."
+                )
+
+        self.supported_tasks = self._get_supported_tasks(self.architectures, self.runner_type, self.convert_type)
+        model_info, arch = registry.inspect_model_cls(self.architectures, self)
+        self._model_info = model_info
+        self._architecture = arch
+
+        self.pooler_config = self._init_pooler_config()
+
+    @property
+    def registry(self):
+        from fastdeploy.model_executor.models.model_base import ModelRegistry
+
+        return ModelRegistry()
 
     def override_name_from_config(self):
         """
```

```diff
@@ -194,7 +297,6 @@ def override_name_from_config(self):
     def read_from_env(self):
         """
         Read configuration information from environment variables and update the object's attributes.
-
         If an attribute is not present or is an empty string in the environment variables, use the default value.
         """
         self.max_stop_seqs_num = int(envs.FD_MAX_STOP_SEQS_NUM)
```

```diff
@@ -235,6 +337,165 @@ def read_model_config(self):
                 f"Config file path: {config_path}"
             )
 
+    def _get_default_runner_type(
+        self,
+        architectures: list[str],
+    ) -> RunnerType:
+        registry = self.registry
+        if get_pooling_config(self.model, self.revision):
+            return "pooling"
+        for arch in architectures:
+            if arch in registry.get_supported_archs():
+                if registry.is_pooling_model(architectures, self):
+                    return "pooling"
+                if registry.is_text_generation_model(architectures, self):
+                    return "generate"
+            match = try_match_architecture_defaults(arch)
+            if match:
+                _, (runner_type, _) = match
+                return runner_type
+        return "generate"
+
+    def _get_default_convert_type(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+    ) -> ConvertType:
+        registry = self.registry
+
+        for arch in architectures:
+            if arch in registry.get_supported_archs():
+                if runner_type == "generate" and registry.is_text_generation_model(architectures, self):
+                    return "none"
+                if runner_type == "pooling" and registry.is_pooling_model(architectures, self):
+                    return "none"
+            match = try_match_architecture_defaults(arch, runner_type=runner_type)
+            if match:
+                _, (_, convert_type) = match
+                return convert_type
+
+        # This is to handle Sentence Transformers models that use *ForCausalLM
+        # and also multi-modal pooling models which are not defined as
+        # Sentence Transformers models
+        if runner_type == "pooling":
+            return "embed"
+
+        return "none"
+
+    def _get_runner_type(
+        self,
+        architectures: list[str],
+        runner: RunnerOption,
+    ) -> RunnerType:
+        if runner != "auto":
+            return runner
+
+        runner_type = self._get_default_runner_type(architectures)
+        if runner_type != "generate":
+            logger.info(
+                "Resolved `--runner auto` to `--runner %s`. Pass the value explicitly to silence this message.",
+                runner_type,
+            )
+
+        return runner_type
+
+    def _get_convert_type(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+        convert: ConvertOption,
+    ) -> ConvertType:
+        if convert != "auto":
+            return convert
+
+        convert_type = self._get_default_convert_type(architectures, runner_type)
+
+        if convert_type != "none":
+            logger.info(
+                "Resolved `--convert auto` to `--convert %s`. Pass the value explicitly to silence this message.",
+                convert_type,
+            )
+
+        return convert_type
+
+    def _get_supported_generation_tasks(
+        self,
+        architectures: list[str],
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_text_generation_model(architectures, self) or convert_type in _RUNNER_CONVERTS["generate"]:
+            supported_tasks.append("generate")
+
+        # TODO: Temporarily does not support transcription.
+        return supported_tasks
+
+    def _get_default_pooling_task(
+        self,
+        architectures: list[str],
+    ) -> Literal["embed"]:
+        # Temporarily does not support classification and reward.
+        for arch in architectures:
+            match = try_match_architecture_defaults(arch, runner_type="pooling")
+            if match:
+                _, (_, convert_type) = match
+                assert convert_type != "none"
+                return convert_type
+
+        return "embed"
+
+    def _get_supported_pooling_tasks(
+        self,
+        architectures: list[str],
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        registry = self.registry
+
+        supported_tasks = list[_ResolvedTask]()
+        if registry.is_pooling_model(architectures, self) or convert_type in _RUNNER_CONVERTS["pooling"]:
+            supported_tasks.append("encode")
+
+        extra_task = self._get_default_pooling_task(architectures) if convert_type == "none" else convert_type
+        supported_tasks.append(extra_task)
+
+        return supported_tasks
+
+    def _get_supported_tasks(
+        self,
+        architectures: list[str],
+        runner_type: RunnerType,
+        convert_type: ConvertType,
+    ) -> list[_ResolvedTask]:
+        if runner_type == "generate":
+            return self._get_supported_generation_tasks(architectures, convert_type)
+        if runner_type == "pooling":
+            return self._get_supported_pooling_tasks(architectures, convert_type)
+
+        assert_never(runner_type)
+
+    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
+        if self.runner_type == "pooling":
+            if isinstance(self.override_pooler_config, dict):
+                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)
+
+            pooler_config = self.override_pooler_config or PoolerConfig()
+
+            base_config = get_pooling_config(self.model, self.revision)
+            if base_config is not None:
+                for k, v in base_config.items():
+                    if getattr(pooler_config, k) is None:
+                        setattr(pooler_config, k, v)
+
+            default_pooling_type = self._model_info.default_pooling_type
+            if pooler_config.pooling_type is None:
+                pooler_config.pooling_type = default_pooling_type
+
+            return pooler_config
+
+        return None
+
     def _get_download_model(self, model_name, model_type="default"):
         # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
         pass
```
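
Two details of the task resolution above are easy to miss: the pooling runner always advertises `encode`, and the second task depends on the convert type (the default pooling task, currently always `embed`, for native pooling models; otherwise the converted task itself). A minimal standalone sketch of `_get_supported_pooling_tasks`, with `is_pooling_model` standing in for the registry call:

```python
from typing import List

RUNNER_CONVERTS = {"generate": [], "pooling": ["embed"]}


def supported_pooling_tasks(is_pooling_model: bool, convert_type: str) -> List[str]:
    # Mirrors _get_supported_pooling_tasks: native pooling models and
    # converted models both expose the raw `encode` task...
    tasks: List[str] = []
    if is_pooling_model or convert_type in RUNNER_CONVERTS["pooling"]:
        tasks.append("encode")
    # ...plus one resolved task: the default ("embed") for native models,
    # or exactly the task the model was converted to.
    tasks.append("embed" if convert_type == "none" else convert_type)
    return tasks


print(supported_pooling_tasks(True, "none"))    # ['encode', 'embed']
print(supported_pooling_tasks(False, "embed"))  # ['encode', 'embed']
```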

```diff
@@ -846,6 +1107,41 @@ def __init__(
             setattr(self, key, value)
 
 
+class PoolerConfig:
+    """Controls the behavior of output pooling in pooling models."""
+
+    pooling_type: Optional[str] = None
+    """
+    The pooling method of the pooling model.
+    """
+    # for embeddings models
+    normalize: Optional[bool] = None
+    """
+    Whether to normalize the embeddings outputs. Defaults to True.
+    """
+    dimensions: Optional[int] = None
+    """
+    Reduce the dimensions of embeddings if the model
+    supports matryoshka representation. Defaults to None.
+    """
+    enable_chunked_processing: Optional[bool] = None
+    """
+    Whether to enable chunked processing for long inputs that exceed the model's
+    maximum position embeddings. When enabled, long inputs will be split into
+    chunks, processed separately, and then aggregated using weighted averaging.
+    This allows embedding models to handle arbitrarily long text without CUDA
+    errors. Defaults to False.
+    """
+    max_embed_len: Optional[int] = None
+    """
+    Maximum input length allowed for embedding generation. When set, allows
+    inputs longer than max_embed_len to be accepted for embedding models.
+    When an input exceeds max_embed_len, it will be handled according to
+    the original max_model_len validation logic.
+    Defaults to None (i.e. set to max_model_len).
+    """
+
+
 class LoRAConfig:
     """LoRA Config"""
 
```
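
`_init_pooler_config` resolves these settings with a fixed precedence: fields set on `override_pooler_config` win, unset fields are filled from the checkpoint's pooling config via `get_pooling_config`, and `pooling_type` finally falls back to the model class default. A dict-based sketch of that merge order (the concrete values here are hypothetical):

```python
# Hypothetical values; the merge order mirrors _init_pooler_config.
override = {"pooling_type": None, "normalize": False}  # user override, wins when set
base = {"pooling_type": "MEAN", "normalize": True}     # from get_pooling_config(...)
default_pooling_type = "LAST"                          # from the model class info

merged = dict(override)
for k, v in base.items():
    if merged.get(k) is None:
        merged[k] = v                                  # checkpoint config fills unset fields
if merged.get("pooling_type") is None:
    merged["pooling_type"] = default_pooling_type      # model default applies last
print(merged)  # {'pooling_type': 'MEAN', 'normalize': False}
```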
