18
18
19
19
import json
20
20
import os
21
+ from dataclasses import field
21
22
from enum import Enum
22
23
from typing import Any , Dict , List , Literal , Optional , Union
23
24
24
25
import paddle
25
26
import paddle .distributed as dist
26
27
from paddleformers .transformers .configuration_utils import PretrainedConfig
28
+ from typing_extensions import assert_never
27
29
28
30
import fastdeploy
29
31
from fastdeploy import envs
30
32
from fastdeploy .model_executor .layers .quantization .quant_base import QuantConfigBase
31
33
from fastdeploy .multimodal .registry import MultimodalRegistry
32
34
from fastdeploy .platforms import current_platform
33
35
from fastdeploy .scheduler import SchedulerConfig
36
+ from fastdeploy .transformer_utils .config import get_pooling_config
34
37
from fastdeploy .utils import ceil_div , check_unified_ckpt , get_host_ip , get_logger
35
38
36
39
logger = get_logger ("config" , "config.log" )
37
40
38
- TaskOption = Literal ["generate" ]
41
+ TaskOption = Literal ["auto" , "generate" , "embedding" , "embed" ]
42
+
43
+ RunnerType = Literal ["generate" , "pooling" ]
44
+
45
+ RunnerOption = Literal ["auto" , "generate" , "pooling" ]
46
+
47
+ ConvertOption = Literal ["auto" , "none" , "embed" ]
48
+
49
+ ConvertType = Literal ["none" , "embed" ]
50
+
51
+ _ResolvedTask = Literal ["generate" , "encode" , "embed" ]
52
+
53
+ _RUNNER_CONVERTS : dict [RunnerType , list [ConvertType ]] = {
54
+ "generate" : [],
55
+ "pooling" : ["embed" ],
56
+ }
57
+
58
+ # Some model suffixes are based on auto classes from Transformers:
59
+ # https://huggingface.co/docs/transformers/en/model_doc/auto
60
+ # NOTE: Items higher on this list priority over lower ones
61
+ _SUFFIX_TO_DEFAULTS : list [tuple [str , tuple [RunnerType , ConvertType ]]] = [
62
+ ("ForCausalLM" , ("generate" , "none" )),
63
+ ("ForConditionalGeneration" , ("generate" , "none" )),
64
+ ("ChatModel" , ("generate" , "none" )),
65
+ ("LMHeadModel" , ("generate" , "none" )),
66
+ ("ForTextEncoding" , ("pooling" , "embed" )),
67
+ ("EmbeddingModel" , ("pooling" , "embed" )),
68
+ ("ForSequenceClassification" , ("pooling" , "classify" )),
69
+ ("ForAudioClassification" , ("pooling" , "classify" )),
70
+ ("ForImageClassification" , ("pooling" , "classify" )),
71
+ ("ForVideoClassification" , ("pooling" , "classify" )),
72
+ ("ClassificationModel" , ("pooling" , "classify" )),
73
+ ("ForRewardModeling" , ("pooling" , "reward" )),
74
+ ("RewardModel" , ("pooling" , "reward" )),
75
+ # Let other `*Model`s take priority
76
+ ("Model" , ("pooling" , "embed" )),
77
+ ]
78
+
79
+
80
+ def iter_architecture_defaults ():
81
+ yield from _SUFFIX_TO_DEFAULTS
82
+
83
+
84
+ def try_match_architecture_defaults (
85
+ architecture : str ,
86
+ * ,
87
+ runner_type : Optional [RunnerType ] = None ,
88
+ convert_type : Optional [ConvertType ] = None ,
89
+ ):
90
+ for suffix , (default_runner_type , default_convert_type ) in iter_architecture_defaults ():
91
+ if (
92
+ (runner_type is None or runner_type == default_runner_type )
93
+ and (convert_type is None or convert_type == default_convert_type )
94
+ and architecture .endswith (suffix )
95
+ ):
96
+ return suffix , (default_runner_type , default_convert_type )
97
+ return None
39
98
40
99
41
100
class MoEPhase :
@@ -133,6 +192,12 @@ def __init__(
133
192
self .eos_tokens_lens : int = 2
134
193
self .lm_head_fp32 : bool = False
135
194
self .model_format = "auto"
195
+ self .runner = "auto"
196
+ self .convert = "auto"
197
+ self .pooler_config : Optional ["PoolerConfig" ] = field (init = False )
198
+ self .override_pooler_config : Optional [Union [dict , "PoolerConfig" ]] = None
199
+ self .revision = None
200
+
136
201
self .partial_rotary_factor : float = 1.0
137
202
self .num_nextn_predict_layers = 0
138
203
for key , value in args .items ():
@@ -161,6 +226,7 @@ def __init__(
161
226
self .ori_vocab_size = args .get ("ori_vocab_size" , self .vocab_size )
162
227
163
228
architectures = self .architectures [0 ]
229
+
164
230
if MultimodalRegistry .contains_model (architectures ):
165
231
self .enable_mm = True
166
232
else :
@@ -171,6 +237,43 @@ def __init__(
171
237
self .override_name_from_config ()
172
238
self .read_from_env ()
173
239
self .read_model_config ()
240
+ self .runner_type = self ._get_runner_type (self .architectures , self .runner )
241
+ self .convert_type = self ._get_convert_type (self .architectures , self .runner_type , self .convert )
242
+
243
+ registry = self .registry
244
+ is_generative_model = registry .is_text_generation_model (self .architectures , self )
245
+ is_pooling_model = registry .is_pooling_model (self .architectures , self )
246
+ is_multimodal_model = registry .is_multimodal_model (self .architectures , self )
247
+
248
+ if self .runner_type == "generate" and not is_generative_model :
249
+ if is_multimodal_model :
250
+ pass
251
+ else :
252
+ generate_converts = _RUNNER_CONVERTS ["generate" ]
253
+ if self .convert_type not in generate_converts :
254
+ raise ValueError ("This model does not support '--runner generate." )
255
+ if self .runner_type == "pooling" and not is_pooling_model :
256
+ pooling_converts = _RUNNER_CONVERTS ["pooling" ]
257
+ if self .convert_type not in pooling_converts :
258
+ convert_option = "<" + "|" .join (pooling_converts ) + ">"
259
+ raise ValueError (
260
+ "This model does not support `--runner pooling`. "
261
+ f"You can pass `--convert { convert_option } to adapt "
262
+ "it into a pooling model."
263
+ )
264
+
265
+ self .supported_tasks = self ._get_supported_tasks (self .architectures , self .runner_type , self .convert_type )
266
+ model_info , arch = registry .inspect_model_cls (self .architectures , self )
267
+ self ._model_info = model_info
268
+ self ._architecture = arch
269
+
270
+ self .pooler_config = self ._init_pooler_config ()
271
+
272
    @property
    def registry(self):
        """Return a (fresh) ModelRegistry instance for architecture lookups.

        Imported lazily inside the property to avoid a circular import with
        ``fastdeploy.model_executor.models.model_base``.
        """
        from fastdeploy.model_executor.models.model_base import ModelRegistry

        return ModelRegistry()
174
277
175
278
def override_name_from_config (self ):
176
279
"""
@@ -194,7 +297,6 @@ def override_name_from_config(self):
194
297
def read_from_env (self ):
195
298
"""
196
299
Read configuration information from environment variables and update the object's attributes.
197
-
198
300
If an attribute is not present or is an empty string in the environment variables, use the default value.
199
301
"""
200
302
self .max_stop_seqs_num = int (envs .FD_MAX_STOP_SEQS_NUM )
@@ -235,6 +337,165 @@ def read_model_config(self):
235
337
f"Config file path: { config_path } "
236
338
)
237
339
340
+ def _get_default_runner_type (
341
+ self ,
342
+ architectures : list [str ],
343
+ ) -> RunnerType :
344
+ registry = self .registry
345
+ if get_pooling_config (self .model , self .revision ):
346
+ return "pooling"
347
+ for arch in architectures :
348
+ if arch in registry .get_supported_archs ():
349
+ if registry .is_pooling_model (architectures , self ):
350
+ return "pooling"
351
+ if registry .is_text_generation_model (architectures , self ):
352
+ return "generate"
353
+ match = try_match_architecture_defaults (arch )
354
+ if match :
355
+ _ , (runner_type , _ ) = match
356
+ return runner_type
357
+ return "generate"
358
+
359
+ def _get_default_convert_type (
360
+ self ,
361
+ architectures : list [str ],
362
+ runner_type : RunnerType ,
363
+ ) -> ConvertType :
364
+ registry = self .registry
365
+
366
+ for arch in architectures :
367
+ if arch in registry .get_supported_archs ():
368
+ if runner_type == "generate" and registry .is_text_generation_model (architectures , self ):
369
+ return "none"
370
+ if runner_type == "pooling" and registry .is_pooling_model (architectures , self ):
371
+ return "none"
372
+ match = try_match_architecture_defaults (arch , runner_type = runner_type )
373
+ if match :
374
+ _ , (_ , convert_type ) = match
375
+ return convert_type
376
+
377
+ # This is to handle Sentence Transformers models that use *ForCausalLM
378
+ # and also multi-modal pooling models which are not defined as
379
+ # Sentence Transformers models
380
+ if runner_type == "pooling" :
381
+ return "embed"
382
+
383
+ return "none"
384
+
385
+ def _get_runner_type (
386
+ self ,
387
+ architectures : list [str ],
388
+ runner : RunnerOption ,
389
+ ) -> RunnerType :
390
+ if runner != "auto" :
391
+ return runner
392
+
393
+ runner_type = self ._get_default_runner_type (architectures )
394
+ if runner_type != "generate" :
395
+ logger .info (
396
+ "Resolved `--runner auto` to `--runner %s`. " "Pass the value explicitly to silence this message." ,
397
+ runner_type ,
398
+ )
399
+
400
+ return runner_type
401
+
402
+ def _get_convert_type (
403
+ self ,
404
+ architectures : list [str ],
405
+ runner_type : RunnerType ,
406
+ convert : ConvertOption ,
407
+ ) -> ConvertType :
408
+ if convert != "auto" :
409
+ return convert
410
+
411
+ convert_type = self ._get_default_convert_type (architectures , runner_type )
412
+
413
+ if convert_type != "none" :
414
+ logger .info (
415
+ "Resolved `--convert auto` to `--convert %s`. " "Pass the value explicitly to silence this message." ,
416
+ convert_type ,
417
+ )
418
+
419
+ return convert_type
420
+
421
+ def _get_supported_generation_tasks (
422
+ self ,
423
+ architectures : list [str ],
424
+ convert_type : ConvertType ,
425
+ ) -> list [_ResolvedTask ]:
426
+ registry = self .registry
427
+
428
+ supported_tasks = list [_ResolvedTask ]()
429
+ if registry .is_text_generation_model (architectures , self ) or convert_type in _RUNNER_CONVERTS ["generate" ]:
430
+ supported_tasks .append ("generate" )
431
+
432
+ # TODO:Temporarily does not support transcription.
433
+ return supported_tasks
434
+
435
    def _get_default_pooling_task(
        self,
        architectures: list[str],
    ) -> Literal["embed"]:
        """Pick the default pooling task from the architecture-suffix defaults.

        Falls back to "embed" when no pooling suffix matches.
        """
        # Temporarily does not support classification and reward.
        for arch in architectures:
            match = try_match_architecture_defaults(arch, runner_type="pooling")
            if match:
                _, (_, convert_type) = match
                # Invariant from _SUFFIX_TO_DEFAULTS: pooling defaults never
                # map to "none".
                assert convert_type != "none"
                # NOTE(review): the suffix defaults can yield "classify" or
                # "reward" here, which the declared Literal["embed"] return
                # type does not admit — confirm whether those suffixes should
                # be filtered out given the TODO above.
                return convert_type

        return "embed"
448
+
449
+ def _get_supported_pooling_tasks (
450
+ self ,
451
+ architectures : list [str ],
452
+ convert_type : ConvertType ,
453
+ ) -> list [_ResolvedTask ]:
454
+ registry = self .registry
455
+
456
+ supported_tasks = list [_ResolvedTask ]()
457
+ if registry .is_pooling_model (architectures , self ) or convert_type in _RUNNER_CONVERTS ["pooling" ]:
458
+ supported_tasks .append ("encode" )
459
+
460
+ extra_task = self ._get_default_pooling_task (architectures ) if convert_type == "none" else convert_type
461
+ supported_tasks .append (extra_task )
462
+
463
+ return supported_tasks
464
+
465
+ def _get_supported_tasks (
466
+ self ,
467
+ architectures : list [str ],
468
+ runner_type : RunnerType ,
469
+ convert_type : ConvertType ,
470
+ ) -> list [_ResolvedTask ]:
471
+ if runner_type == "generate" :
472
+ return self ._get_supported_generation_tasks (architectures , convert_type )
473
+ if runner_type == "pooling" :
474
+ return self ._get_supported_pooling_tasks (architectures , convert_type )
475
+
476
+ assert_never (runner_type )
477
+
478
    def _init_pooler_config(self) -> Optional["PoolerConfig"]:
        """Build the effective PoolerConfig for pooling runners, else None.

        Precedence per field: user override > repo-level pooling config >
        model default pooling type. Side effect: normalizes
        ``self.override_pooler_config`` from dict to PoolerConfig.
        """
        if self.runner_type == "pooling":
            if isinstance(self.override_pooler_config, dict):
                self.override_pooler_config = PoolerConfig(**self.override_pooler_config)

            pooler_config = self.override_pooler_config or PoolerConfig()

            # Fill fields the user left unset from the repo's pooling config.
            # NOTE(review): assumes every key in base_config is an existing
            # PoolerConfig attribute — getattr raises AttributeError
            # otherwise. Confirm against get_pooling_config's output.
            base_config = get_pooling_config(self.model, self.revision)
            if base_config is not None:
                for k, v in base_config.items():
                    if getattr(pooler_config, k) is None:
                        setattr(pooler_config, k, v)

            # Last resort: the model class's declared default pooling type.
            default_pooling_type = self._model_info.default_pooling_type
            if pooler_config.pooling_type is None:
                pooler_config.pooling_type = default_pooling_type

            return pooler_config

        return None
498
+
238
499
    def _get_download_model(self, model_name, model_type="default"):
        """Placeholder for self-downloading model weights (not implemented)."""
        # TODO: Provide dynamic graph for self-downloading and save to the specified download directory.
        pass
@@ -846,6 +1107,41 @@ def __init__(
846
1107
setattr (self , key , value )
847
1108
848
1109
1110
class PoolerConfig:
    """Controls the behavior of output pooling in pooling models.

    All fields default to None, meaning "not set"; unset fields are filled in
    later from the repo's pooling config or the model's defaults (see
    ``ModelConfig._init_pooler_config``).
    """

    # The pooling method of the pooling model.
    pooling_type: Optional[str] = None

    # --- embedding models ---

    # Whether to normalize the embeddings outputs. Defaults to True.
    normalize: Optional[bool] = None

    # Reduce the dimensions of embeddings if the model supports matryoshka
    # representation. Defaults to None.
    dimensions: Optional[int] = None

    # Whether to enable chunked processing for long inputs that exceed the
    # model's maximum position embeddings. When enabled, long inputs are split
    # into chunks, processed separately, then aggregated using weighted
    # averaging, letting embedding models handle arbitrarily long text without
    # CUDA errors. Defaults to False.
    enable_chunked_processing: Optional[bool] = None

    # Maximum input length allowed for embedding generation. When set, inputs
    # longer than max_embed_len are accepted for embedding models; an input
    # exceeding max_embed_len is handled according to the original
    # max_model_len validation logic. Defaults to None (i.e. max_model_len).
    max_embed_len: Optional[int] = None

    def __init__(self, **kwargs: Any) -> None:
        """Create a config, optionally overriding fields by keyword.

        Accepting keyword arguments fixes construction from a user-supplied
        dict (``PoolerConfig(**override_pooler_config)``), which previously
        raised TypeError because the class defined no __init__.
        """
        for key, value in kwargs.items():
            setattr(self, key, value)
+
849
1145
class LoRAConfig :
850
1146
"""LoRA Config"""
851
1147
0 commit comments