 from transformers import PretrainedConfig
 
 from vllm.config import ModelConfig, PoolerConfig
+from vllm.logger import init_logger
 from vllm.pooling_params import PoolingParams
 from vllm.sequence import PoolerOutput, PoolingSequenceGroupOutput
 from vllm.tasks import PoolingTask
 from vllm.utils import current_stream, resolve_obj_by_qualname
 from vllm.v1.pool.metadata import PoolingCursor, PoolingMetadata
 
+logger = init_logger(__name__)
+
 PoolingFn = Callable[
     [Union[torch.Tensor, list[torch.Tensor]], PoolingMetadata],
     Union[torch.Tensor, list[torch.Tensor]]]
@@ -183,7 +186,7 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         fn = resolve_obj_by_qualname(function_name)()
         return PoolerActivation.wraps(fn)
 
-    return PoolerScore()
+    return PoolerClassify()
 
 
 def build_output(
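Note: for cross-encoders with no activation function named in their config, the fallback changes from the pass-through `PoolerScore` to `PoolerClassify`, so multi-label scores now come back normalized instead of as raw logits. A minimal before/after sketch (the tensor values are made up for illustration):

```python
import torch

logits = torch.tensor([[2.0, -1.0]])

# Old fallback (PoolerScore): identity for >= 2 labels, raw logits returned.
old_scores = logits
# New fallback (PoolerClassify): softmax over the label dimension.
new_scores = torch.softmax(logits.float(), dim=-1).to(logits.dtype)
print(old_scores, new_scores)
```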
@@ -371,22 +374,29 @@ def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
 
 class PoolerClassify(PoolerActivation):
 
-    def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        num_labels = pooled_data.shape[-1]
-        if num_labels < 2:
-            return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
-
-        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
-
+    def __init__(self, *, static_num_labels: bool = True) -> None:
+        super().__init__()
 
-class PoolerScore(PoolerActivation):
+        if static_num_labels:
+            from vllm.config import get_current_vllm_config
+            vllm_config = get_current_vllm_config()
+            self.num_labels = getattr(vllm_config.model_config.hf_config,
+                                      "num_labels", 0)
+            if self.num_labels == 0:
+                logger.warning("num_labels should be > 0 for classification "
+                               "models, falling back to softmax. "
+                               "Please check if the configuration is correct.")
+        else:
+            self.num_labels = None
 
     def forward_chunk(self, pooled_data: torch.Tensor) -> torch.Tensor:
-        num_labels = pooled_data.shape[-1]
+        num_labels = (self.num_labels if self.num_labels is not None else
+                      pooled_data.shape[-1])
+
         if num_labels < 2:
             return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
 
-        return pooled_data
+        return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)
 
 
 class LambdaPoolerActivation(PoolerActivation):
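A self-contained sketch of the selection logic above: `num_labels` is now read from the HF config at construction time (when `static_num_labels=True`) and only falls back to the tensor's last dimension otherwise. The helper function and sample tensors below are hypothetical, for illustration only:

```python
from typing import Optional

import torch
import torch.nn.functional as F

def classify_activation(pooled_data: torch.Tensor,
                        num_labels: Optional[int]) -> torch.Tensor:
    # num_labels=None mimics static_num_labels=False: infer from the data.
    if num_labels is None:
        num_labels = pooled_data.shape[-1]
    if num_labels < 2:
        return F.sigmoid(pooled_data.float()).to(pooled_data.dtype)
    return F.softmax(pooled_data.float(), dim=-1).to(pooled_data.dtype)

print(classify_activation(torch.tensor([[0.7]]), num_labels=1))             # sigmoid path
print(classify_activation(torch.tensor([[2.0, -1.0, 0.5]]), num_labels=3))  # softmax path
```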
@@ -428,6 +438,10 @@ def __init__(self) -> None:
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
 
+        if isinstance(pooled_data, list):
+            pooled_data = torch.stack(pooled_data)
+        # pooled_data shape: [batchsize, hidden_dimension]
+
         # Apply ST projector
         if self.projector is not None:
             projector = cast(nn.Module, self.projector)
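Stacking the pooled vectors up front works because each sequence pools to a single fixed-size vector; a short sketch of the shape change (sizes are hypothetical):

```python
import torch

pooled_list = [torch.randn(1024) for _ in range(4)]  # one vector per sequence
pooled_data = torch.stack(pooled_list)               # [batchsize, hidden_dimension]
print(pooled_data.shape)                             # torch.Size([4, 1024])
```

This lets the ST projector and the later per-request steps operate on one batched tensor instead of branching on list vs. tensor inputs, which is what the deletions in the next hunk clean up.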
@@ -437,17 +451,11 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
                 y = projector(x.to(torch.float32))
                 return y.to(orig_dtype)
 
-            if isinstance(pooled_data, torch.Tensor):
-                pooled_data = _proj(pooled_data)
-            else:
-                pooled_data = [_proj(t) for t in pooled_data]
+            pooled_data = _proj(pooled_data)
+            # pooled_data shape: [batchsize, embedding_dimension]
 
         pooling_params = get_pooling_params(pooling_metadata)
 
-        if isinstance(pooled_data, list):
-            pooled_data = torch.stack(pooled_data)
-        # pooled_data shape: [batchsize, embedding_dimension]
-
         # for matryoshka representation
         dimensions_list = [
             pooling_param.dimensions for pooling_param in pooling_params
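For context, the matryoshka path that follows uses each request's `dimensions` to serve smaller embeddings from the same pooled vector. A hedged sketch of the idea, assuming truncation of the leading components with re-normalization (as is typical for matryoshka embeddings; both details are illustrative assumptions, not a claim about this exact code path):

```python
import torch
import torch.nn.functional as F

embeddings = torch.randn(3, 768)  # [batchsize, embedding_dimension]
requested_dims = [256, 768, 64]   # hypothetical per-request matryoshka dims

# Keep the leading d components of each vector, then re-normalize.
truncated = [F.normalize(vec[:d], dim=-1)
             for vec, d in zip(embeddings, requested_dims)]
print([t.shape for t in truncated])
```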
@@ -477,13 +485,14 @@ def _proj(x: torch.Tensor) -> torch.Tensor:
             for vecs, f in zip(pooled_data, flags)
         ]
 
+        # pooled_data shape: [batchsize, embedding_dimension]
         return pooled_data
 
 
 class RewardPoolerHead(PoolerHead):
 
     def __init__(self) -> None:
-        super().__init__(activation=PoolerClassify())
+        super().__init__(activation=PoolerClassify(static_num_labels=False))
 
     def forward(self, pooled_data: Union[list[torch.Tensor], torch.Tensor],
                 pooling_metadata: PoolingMetadata):
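`RewardPoolerHead` opts out of the config lookup with `static_num_labels=False`: reward outputs are not tied to `config.num_labels`, so the label count must come from the tensor itself at runtime. A sketch of the resulting behavior (shapes are hypothetical):

```python
import torch
import torch.nn.functional as F

rewards = torch.randn(2, 32000)  # hypothetical [batchsize, output_dim] reward head
num_labels = rewards.shape[-1]   # runtime fallback when self.num_labels is None
probs = F.softmax(rewards.float(), dim=-1).to(rewards.dtype)
print(probs.shape, probs.sum(-1))  # each row sums to ~1
```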
@@ -637,19 +646,13 @@ def forward(
         pooling_metadata: PoolingMetadata,
     ) -> PoolerOutput:
         pooled_data = self.pooling(hidden_states, pooling_metadata)
-
         if isinstance(pooled_data, list):
             pooled_data = torch.stack(pooled_data)
         # pooled_data shape: [batchsize, hidden_size]
 
         if self.classifier is not None:
-            # apply classifier once on the full batch if possible
-            if isinstance(pooled_data, torch.Tensor):
-                pooled_data = self.classifier(pooled_data)
-            elif len({data.shape for data in pooled_data}) <= 1:
-                pooled_data = self.classifier(torch.stack(pooled_data))
-            else:
-                pooled_data = [self.classifier(data) for data in pooled_data]
+            pooled_data = self.classifier(pooled_data)
+            # pooled_data shape: [batchsize, num_labels]
 
         pooling_params = get_pooling_params(pooling_metadata)
         flags = [p.activation for p in pooling_params]
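With `pooled_data` already stacked to `[batchsize, hidden_size]`, the classifier runs once over the whole batch, replacing the three-way branch that previously handled a plain tensor, an equal-shape list, and a ragged list. A minimal sketch (the linear head is a stand-in for `self.classifier`):

```python
import torch
import torch.nn as nn

hidden_size, num_labels = 1024, 3
classifier = nn.Linear(hidden_size, num_labels)  # stand-in for the model's head

pooled_data = torch.randn(8, hidden_size)  # [batchsize, hidden_size]
logits = classifier(pooled_data)           # [batchsize, num_labels]
print(logits.shape)
```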
@@ -662,6 +665,7 @@ def forward(
             for vecs, f in zip(pooled_data, flags)
         ]
 
+        # scores shape: [batchsize, num_labels]
         return build_output(scores)