@@ -42,7 +42,10 @@ impl Qwen3Attention {
             "weight",
         )?;
         let query_bias = if config.attention_bias {
-            Some(vb.pp("q_proj").get(hidden_size, "bias")?)
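+            // Qwen3 can set head_dim explicitly, so q_proj's output dimension is
+            // num_attention_heads * attention_head_size rather than hidden_size.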
+            Some(
+                vb.pp("q_proj")
+                    .get(num_attention_heads * attention_head_size, "bias")?,
+            )
         } else {
             None
         };
@@ -85,7 +88,7 @@ impl Qwen3Attention {
         let q_norm = RMSNorm::load(vb.pp("q_norm"), attention_head_size, config.rms_norm_eps)?;
         let k_norm = RMSNorm::load(vb.pp("k_norm"), attention_head_size, config.rms_norm_eps)?;

-        let softmax_scale = (1. / (attention_head_size as f64).sqrt()) as f32;
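+        // Standard 1/sqrt(head_dim) attention scaling; note the `as f32` cast binds
+        // to the sqrt result, so the division itself happens in f32.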
+        let softmax_scale = 1.0 / (attention_head_size as f64).sqrt() as f32;

         Ok(Self {
             q_proj,
@@ -148,6 +151,28 @@ impl Qwen3Attention {

         apply_rotary_inplace(&q, &k, &cos, &sin, true)?;

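+        // Grouped-query attention: each KV head serves several query heads, so expand
+        // K and V `repeat` times to match the query head count flash_attn_varlen expects.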
+        let (k, v) = if self.num_key_value_heads != self.num_attention_heads {
+            if self.num_attention_heads % self.num_key_value_heads != 0 {
+                candle::bail!("num_attention_heads must be a multiple of num_key_value_heads");
+            }
+            let repeat = self.num_attention_heads / self.num_key_value_heads;
+
+            let (total_tokens, n_kv_heads, head_dim) = k.dims3()?;
+
+            let k = k
+                .unsqueeze(2)?
+                .expand((total_tokens, n_kv_heads, repeat, head_dim))?
+                .reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
+
+            let v = v
+                .unsqueeze(2)?
+                .expand((total_tokens, n_kv_heads, repeat, head_dim))?
+                .reshape((total_tokens, n_kv_heads * repeat, head_dim))?;
+            (k, v)
+        } else {
+            (k, v)
+        };
+
         let attention = flash_attn_varlen(
             &q,
             &k,
@@ -277,101 +302,20 @@ impl Qwen3Layer {

         let mlp_output = self.mlp.forward(&normed_attn_res_output)?;

-        Ok((mlp_output, attn_res))
-    }
-}
-
-// Define ClassificationHead trait locally (following TEI pattern)
-trait ClassificationHead {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor>;
-}
-
-// Qwen3 Classification Head implementation
-#[derive(Debug)]
-struct Qwen3ClassificationHead {
-    dense: Linear,
-    out_proj: Linear,
-    activation: HiddenAct,
-    span: tracing::Span,
-}
-
-impl Qwen3ClassificationHead {
-    pub fn load(vb: VarBuilder, config: &Qwen3Config) -> Result<Self> {
-        let (dense, out_proj) = if vb.contains_tensor("score.dense.weight") {
-            tracing::info!("Loading Qwen3 classifier with score layers");
-
-            let dense_weight = vb
-                .pp("score.dense")
-                .get((config.hidden_size, config.hidden_size), "weight")?;
-            let dense_bias = vb.pp("score.dense").get(config.hidden_size, "bias")?;
-            let dense = Linear::new(dense_weight, Some(dense_bias), None);
-
-            let out_proj_weight = vb
-                .pp("score.out_proj")
-                .get((1, config.hidden_size), "weight")?;
-            let out_proj_bias = vb.pp("score.out_proj").get(1, "bias")?;
-            let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
-
-            (dense, out_proj)
-        } else if vb.contains_tensor("classifier.dense.weight") {
-            tracing::info!("Loading Qwen3 classifier with classifier layers");
-
-            let dense_weight = vb
-                .pp("classifier.dense")
-                .get((config.hidden_size, config.hidden_size), "weight")?;
-            let dense_bias = vb.pp("classifier.dense").get(config.hidden_size, "bias")?;
-            let dense = Linear::new(dense_weight, Some(dense_bias), None);
-
-            let out_proj_weight = vb
-                .pp("classifier.out_proj")
-                .get((1, config.hidden_size), "weight")?;
-            let out_proj_bias = vb.pp("classifier.out_proj").get(1, "bias")?;
-            let out_proj = Linear::new(out_proj_weight, Some(out_proj_bias), None);
-
-            (dense, out_proj)
-        } else {
-            candle::bail!(
-                "Classification layers not found in model weights. \
-                 Expected 'score.dense.weight' or 'classifier.dense.weight' for reranker models. \
-                 This model may not be a trained reranker."
-            );
-        };
-
-        Ok(Self {
-            dense,
-            out_proj,
-            activation: config.hidden_act.clone(),
-            span: tracing::span!(tracing::Level::TRACE, "classifier"),
-        })
-    }
-}
-
-impl ClassificationHead for Qwen3ClassificationHead {
-    fn forward(&self, hidden_states: &Tensor) -> Result<Tensor> {
-        let _enter = self.span.enter();
-
-        // Input is already pooled
-
-        // Apply dense layer with activation
-        let hidden = self.dense.forward(hidden_states)?;
-        let hidden = self.activation.forward(&hidden)?;
-
-        // Project to single score
-        let score = self.out_proj.forward(&hidden)?;
-
-        // Squeeze to remove the last dimension if it's 1
-        score.squeeze(candle::D::Minus1)
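+        // Fold the residual connection into the layer's output so callers no longer
+        // have to thread a residual tensor between layers.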
+        let output = (&mlp_output + &attn_res)?;
+        Ok((output, attn_res))
     }
 }

 pub struct FlashQwen3Model {
     embeddings: Embedding,
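+    // Input-embedding matrix reused as the output projection (see `predict`).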
+    lm_head_weight: Tensor,
     layers: Vec<Qwen3Layer>,
     norm: RMSNorm,
     cos_cache: Tensor,
     sin_cache: Tensor,
+    model_type: ModelType,
     pool: Pool,
-    classifier: Option<Box<dyn ClassificationHead + Send>>,
     pub device: Device,

     span: tracing::Span,
@@ -388,19 +332,12 @@ impl FlashQwen3Model {
             candle::bail!("FlashQwen3 requires DType::F16")
         }

-        let (pool, classifier) = match model_type {
+        let pool = match &model_type {
             ModelType::Classifier => {
-                let pool = Pool::LastToken;
-                let classifier: Box<dyn ClassificationHead + Send> =
-                    Box::new(Qwen3ClassificationHead::load(vb.clone(), config)?);
-                (pool, Some(classifier))
-            }
-            ModelType::Embedding(pool) => {
-                if pool == Pool::Splade {
-                    candle::bail!("`splade` is not supported for Qwen3")
-                }
-                (pool, None)
+                candle::bail!("`classifier` model type is not supported for Qwen3")
             }
+            ModelType::Embedding(pool) => pool.clone(),
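+            // Reranking reads the model's answer from the final token, so force last-token pooling.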
+            ModelType::ListwiseReranker => Pool::LastToken,
         };

         // The Qwen3-Reranker models contain the `model` key
@@ -411,11 +348,13 @@ impl FlashQwen3Model {
             vb
         };

-        let embeddings = Embedding::new(
-            vb.pp("embed_tokens")
-                .get((config.vocab_size, config.hidden_size), "weight")?,
-            config.hidden_size,
-        );
+        let embed_weight = vb
+            .pp("embed_tokens")
+            .get((config.vocab_size, config.hidden_size), "weight")?;
+
+        let embeddings = Embedding::new(embed_weight.clone(), config.hidden_size);
+
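+        // Qwen3 ties lm_head to embed_tokens, so the embedding matrix doubles as the
+        // output projection used by `predict`.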
+        let lm_head_weight = embed_weight;

         let layers = (0..config.num_hidden_layers)
             .map(|index| Qwen3Layer::load(vb.pp(format!("layers.{index}")), config))
@@ -438,12 +377,13 @@ impl FlashQwen3Model {

         Ok(Self {
             embeddings,
+            lm_head_weight,
             layers,
             norm,
             cos_cache,
             sin_cache,
+            model_type,
             pool,
-            classifier,
             device: vb.device().clone(),
             span: tracing::span!(tracing::Level::TRACE, "model"),
         })
@@ -469,21 +409,19 @@ impl FlashQwen3Model {
         let cos = self.cos_cache.index_select(&position_ids, 0)?;
         let sin = self.sin_cache.index_select(&position_ids, 0)?;

-        let mut residual = None;
         for layer in &self.layers {
-            let (h, r) = layer.forward(
+            let (h, _r) = layer.forward(
                 &hidden_states,
-                residual.as_ref(),
+                None,
                 &cu_seqlens,
                 &cos,
                 &sin,
                 batch.max_length as usize,
             )?;
             hidden_states = h;
-            residual = Some(r);
         }

-        let (outputs, _) = self.norm.forward(&hidden_states, residual.as_ref())?;
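+        // Layers now return residual-combined outputs (see Qwen3Layer::forward), so
+        // no residual is threaded through the loop or into the final norm.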
+        let (outputs, _) = self.norm.forward(&hidden_states, None)?;

         let has_pooling_requests = !batch.pooled_indices.is_empty();
         let has_raw_requests = !batch.raw_indices.is_empty();
@@ -553,7 +491,8 @@ impl FlashQwen3Model {
                         // Concatenate all results
                         Some(Tensor::cat(&results?, 0)?)
                     } else {
-                        Some((outputs.sum_keepdim(0)? / (batch.max_length as f64))?)
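+                        // Divide by the sequence's actual token count; `max_length`
+                        // may differ from it and would skew the mean.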
+                        let actual_len = batch.cumulative_seq_lengths[1] as f64;
+                        Some((outputs.sum_keepdim(0)? / actual_len)?)
                     }
                 }
                 Pool::Splade => {
@@ -607,21 +546,64 @@ impl Model for FlashQwen3Model {
     }

     fn predict(&self, batch: Batch) -> Result<Tensor> {
-        match &self.classifier {
-            None => candle::bail!("`predict` is not implemented for this model"),
-            Some(classifier) => {
-                // Run forward pass to get hidden states
-                let (pooled_embeddings, _) = self.forward(batch)?;
-                match pooled_embeddings {
-                    Some(embeddings) => {
-                        let scores = classifier.forward(&embeddings)?;
-                        // Apply sigmoid to convert logits to probabilities
-                        let probabilities = candle_nn::ops::sigmoid(&scores)?;
-                        Ok(probabilities)
-                    }
-                    None => candle::bail!("No pooled embeddings returned for classification"),
+        match &self.model_type {
+            ModelType::ListwiseReranker => {
+                let _enter = self.span.enter();
+
+                let batch_size = batch.cumulative_seq_lengths.len() - 1;
+                let shape = batch.input_ids.len();
+
+                let input_ids = Tensor::from_vec(batch.input_ids, shape, &self.device)?;
+                let position_ids = Tensor::from_vec(batch.position_ids, shape, &self.device)?;
+                let cu_seqlens = Tensor::from_vec(
+                    batch.cumulative_seq_lengths.clone(),
+                    batch_size + 1,
+                    &self.device,
+                )?;
+
+                let mut hidden_states = self.embeddings.forward(&input_ids)?;
+
+                let cos = self.cos_cache.index_select(&position_ids, 0)?;
+                let sin = self.sin_cache.index_select(&position_ids, 0)?;
+
+                for layer in &self.layers {
+                    let (h, _r) = layer.forward(
+                        &hidden_states,
+                        None,
+                        &cu_seqlens,
+                        &cos,
+                        &sin,
+                        batch.max_length as usize,
+                    )?;
+                    hidden_states = h;
                 }
+
+                let (outputs, _) = self.norm.forward(&hidden_states, None)?;
+
+                let mut last_hidden_states = Vec::with_capacity(batch_size);
+
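+                // Sequences are packed back-to-back (varlen layout); the last token of
+                // sequence i sits at cumulative_seq_lengths[i + 1] - 1.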
+                for i in 0..batch_size {
+                    let seq_end = batch.cumulative_seq_lengths[i + 1] as usize;
+                    let last_token_idx = seq_end - 1;
+
+                    let h_last = outputs.i(last_token_idx)?; // [hidden_size]
+                    last_hidden_states.push(h_last);
+                }
+
+                let h_last = Tensor::stack(&last_hidden_states, 0)?; // [bs, hidden_size]
+
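+                // The reranker prompt asks the model to answer "yes" or "no"; these are
+                // the Qwen3 tokenizer ids used here for those answer tokens.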
+                let true_id = 9693u32;
+                let false_id = 2152u32;
+
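+                // Decode only the two relevant lm_head rows instead of full-vocab logits;
+                // softmax over (no, yes) then gives P("yes") as the relevance score.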
+                let ids = Tensor::from_vec(vec![false_id, true_id], 2, &self.device)?;
+                let w = self.lm_head_weight.index_select(&ids, 0)?; // [2, hidden_size]
+                let logits = h_last.matmul(&w.t()?)?; // [bs, 2] (no, yes)
+                let log_probs = candle_nn::ops::log_softmax(&logits, D::Minus1)?;
+                let scores = log_probs.i((.., 1))?.exp()?; // P("yes") ∈ (0, 1)
+
+                Ok(scores)
             }
+            _ => candle::bail!("`predict` is only available for ModelType::ListwiseReranker"),
         }
     }
 }