23
23
from fastdeploy .distributed .communication import tensor_model_parallel_all_reduce
24
24
from fastdeploy .platforms import current_platform
25
25
26
- from ..utils import create_and_set_parameter , get_tensor
26
+ from ..utils import get_tensor
27
27
from .fused_moe_backend_base import UnquantizedFusedMoEMethod
28
28
29
29
if current_platform .is_cuda ():
@@ -202,7 +202,10 @@ def apply_ep_decode(
202
202
gate_out = gate (x .cast ("float32" ))
203
203
# 1. Select topk experts and weights
204
204
topk_idx , topk_weights = self .ep_decoder_runner .moe_select (layer , gate_out )
205
- expertwise_scale = getattr (layer , "up_gate_proj_in_scale_all_experts" , None )
205
+ expertwise_scale = None
206
+ if hasattr (layer , "up_gate_proj_in_scale_all_experts" ): # only use in w4a8
207
+ expertwise_scale = getattr (layer , "up_gate_proj_in_scale_all_experts" , None )
208
+
206
209
# 2. EP Dispatch
207
210
permute_input , token_nums_per_expert , handle = self .ep_decoder_runner .dispatch (
208
211
x , topk_idx , topk_weights , expertwise_scale = expertwise_scale
@@ -382,12 +385,48 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
382
385
"down_proj_in_scale" : down_proj_in_scale ,
383
386
}
384
387
for name , tensor in name_tensor_map .items ():
385
- create_and_set_parameter (layer , name , tensor )
388
+ getattr (layer , name ). set_value ( tensor )
386
389
387
- def create_weights (self , layer : nn .Layer , state_dict ):
390
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
    """
    Allocate zero-filled int8 placeholders for the two weights named by
    ``self.added_weight_attrs[0]`` and ``self.added_weight_attrs[1]``, then
    create the accompanying w4a8 scale parameters.

    Args:
        layer (nn.Layer): Layer that receives the new parameters.
        **extra_weight_attrs: Accepted for interface compatibility; not read
            by this method.
    """
    self.weight_dtype = "int8"

    local_experts = layer.num_local_experts
    hidden = layer.hidden_size
    intermediate = layer.moe_intermediate_size

    # One dimension of each shape is halved; presumably two 4-bit values
    # share one int8 element -- TODO confirm against the kernel layout.
    self.ffn1_weight_shape = [local_experts, hidden // 2, intermediate * 2]
    self.ffn2_weight_shape = [local_experts, intermediate // 2, hidden]

    weight_shapes = (self.ffn1_weight_shape, self.ffn2_weight_shape)
    for position, weight_shape in enumerate(weight_shapes):
        setattr(
            layer,
            self.added_weight_attrs[position],
            layer.create_parameter(
                shape=weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )

    # Scale parameters are created by the dedicated helper on this class.
    self.create_w4a8_scale_weights(layer, layer.weight_key_map)
425
+
426
+ def process_loaded_weights (self , layer : nn .Layer , state_dict ):
427
+ """
428
+ Paddle cutlass load weight process.
429
+ """
391
430
up_gate_proj_weights , down_proj_weights = layer .extract_moe_ffn_weights (state_dict )
392
431
self .check (layer , up_gate_proj_weights , down_proj_weights )
393
432
for idx , weight_tensor in enumerate ([up_gate_proj_weights , down_proj_weights ]):
@@ -397,11 +436,63 @@ def create_weights(self, layer: nn.Layer, state_dict):
397
436
quant_weight , scale = weight_quantize (weight_tensor [i ], algo = self .moe_quant_type , arch = 80 )
398
437
weight_list .append (quant_weight )
399
438
quanted_weight = paddle .stack (weight_list , axis = 0 )
400
- create_and_set_parameter (layer , weight_name , quanted_weight )
439
+ getattr (layer , weight_name ). set_value ( quanted_weight )
401
440
402
- self .create_w4a8_scale_weights (layer , layer .weight_key_map , state_dict )
441
+ self .load_w4a8_scale_weights (layer , layer .weight_key_map , state_dict )
403
442
404
- def create_w4a8_scale_weights (self , layer : nn .Layer , weight_key_map : dict , state_dict : dict ):
443
def create_w4a8_scale_weights(self, layer: nn.Layer, weight_key_map: dict):
    """
    Create zero-initialized placeholder parameters for the w4a8 scales.

    Nothing is read from a checkpoint here; the parameters are only
    allocated on ``layer`` so that their values can be assigned afterwards
    (the loading code in this file does so through ``set_value``).

    Parameters created, in this order:
        * ``up_gate_proj_in_scale_all_experts`` -- float32,
          ``[layer.num_experts]``; only when ``layer.ep_size > 1``.
        * ``up_gate_proj_in_scale`` and ``down_proj_in_scale`` -- float32,
          ``[layer.num_local_experts]``.
        * ``up_gate_proj_weight_scale`` -- default dtype,
          ``[layer.num_local_experts, layer.moe_intermediate_size * 2]``.
        * ``down_proj_weight_scale`` -- default dtype,
          ``[layer.num_local_experts, layer.hidden_size]``.

    Also stores the layer helper's default dtype in ``self.default_dtype``.

    Args:
        layer (nn.Layer): The layer to add parameters to.
        weight_key_map (dict): The weight key map. Not read by this method;
            kept so existing callers keep working.
    """
    self.default_dtype = layer._helper.get_default_dtype()

    def _add_zero_parameter(attr_name, shape, dtype):
        # Attach a parameter filled with zeros to `layer` under `attr_name`.
        setattr(
            layer,
            attr_name,
            layer.create_parameter(
                shape=shape,
                dtype=dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )

    # The all-experts input scale exists only under expert parallelism.
    if layer.ep_size > 1:
        _add_zero_parameter("up_gate_proj_in_scale_all_experts", [layer.num_experts], "float32")

    # in_scales
    for in_scale_name in ["up_gate_proj_in_scale", "down_proj_in_scale"]:
        _add_zero_parameter(in_scale_name, [layer.num_local_experts], "float32")

    # weight_scales
    _add_zero_parameter(
        "up_gate_proj_weight_scale",
        [layer.num_local_experts, layer.moe_intermediate_size * 2],
        self.default_dtype,
    )
    _add_zero_parameter(
        "down_proj_weight_scale",
        [layer.num_local_experts, layer.hidden_size],
        self.default_dtype,
    )
494
+
495
+ def load_w4a8_scale_weights (self , layer : nn .Layer , weight_key_map : dict , state_dict : dict ):
405
496
"""
406
497
Get w4a8 weights from state dict and process them.
407
498
Args:
@@ -415,7 +506,7 @@ def _extract_scale_tensor(state_dict, key_template, expert_idx):
415
506
416
507
def _process_in_scale(name: str, in_scales: list[paddle.Tensor]):
    """
    Concatenate ``in_scales``, take the element-wise reciprocal, and write
    the result into the already existing attribute ``name`` of ``layer``.

    ``layer`` is not a parameter; it is taken from the enclosing scope.

    Args:
        name (str): Attribute on ``layer`` whose value is overwritten.
        in_scales (list[paddle.Tensor]): Scales to concatenate.

    Returns:
        The reciprocal of the concatenated scales.
    """
    inverted_scale = 1 / paddle.concat(in_scales)
    target_param = getattr(layer, name)
    target_param.set_value(inverted_scale)
    return inverted_scale
420
511
421
512
def _process_weight_scale (
@@ -426,7 +517,7 @@ def _process_weight_scale(
426
517
processed_weight_scale = (
427
518
paddle .stack (weight_scales , axis = 0 ) / (127 * 112 ) / processed_in_scale [:, None ]
428
519
).cast (paddle .get_default_dtype ())
429
- create_and_set_parameter (layer , name , processed_weight_scale )
520
+ getattr (layer , name ). set_value ( processed_weight_scale )
430
521
431
522
# 1. Init scale containers and maps
432
523
up_gate_proj_weight_scales = []
@@ -456,8 +547,8 @@ def _process_weight_scale(
456
547
for expert_idx in range (layer .num_experts ):
457
548
scale_tensor = get_tensor (state_dict [scale_key_map ["up_gate_proj_in_scale" ].format (expert_idx )])
458
549
up_gate_proj_in_scales_all_experts .append (1 / scale_tensor )
459
- create_and_set_parameter (
460
- layer , "up_gate_proj_in_scale_all_experts" , paddle .concat (up_gate_proj_in_scales_all_experts )
550
+ getattr ( layer , "up_gate_proj_in_scale_all_experts" ). set_value (
551
+ paddle .concat (up_gate_proj_in_scales_all_experts )
461
552
)
462
553
463
554
for local_expert_idx in range (layer .num_local_experts ):
@@ -527,15 +618,85 @@ def process_prequanted_weights(self, layer: nn.Layer, state_dict):
527
618
"down_proj_weight_scale" : down_proj_weight_scale ,
528
619
}
529
620
for name , tensor in name_tensor_map .items ():
530
- create_and_set_parameter (layer , name , tensor )
621
+ getattr (layer , name ). set_value ( tensor )
531
622
532
- def create_weights (self , layer : nn .Layer , state_dict ):
623
def create_weights(self, layer: nn.Layer, **extra_weight_attrs):
    """
    Allocate zero-filled placeholders for the two quantized weights and
    their two per-expert weight scales.

    Weights are int8 and registered under ``self.added_weight_attrs[0]`` and
    ``[1]``; scales use the layer helper's default dtype and are registered
    under ``self.added_scale_attrs[0]`` and ``[1]``.

    Args:
        layer (nn.Layer): Layer that receives the new parameters.
        **extra_weight_attrs: Accepted for interface compatibility; not read
            by this method.
    """
    self.default_dtype = layer._helper.get_default_dtype()
    self.weight_dtype = "int8"

    weight_names = (self.added_weight_attrs[0], self.added_weight_attrs[1])
    is_int4 = self.moe_quant_type == "weight_only_int4"

    local_experts = layer.num_local_experts
    intermediate = layer.moe_intermediate_size
    hidden = layer.hidden_size

    # For "weight_only_int4" one dimension of each weight is half as large;
    # presumably two 4-bit values share one int8 element -- TODO confirm.
    self.ffn1_weight_shape = [
        local_experts,
        intermediate if is_int4 else intermediate * 2,
        hidden,
    ]
    self.ffn2_weight_shape = [
        local_experts,
        hidden // 2 if is_int4 else hidden,
        intermediate,
    ]

    weight_shapes = (self.ffn1_weight_shape, self.ffn2_weight_shape)
    for weight_name, weight_shape in zip(weight_names, weight_shapes):
        setattr(
            layer,
            weight_name,
            layer.create_parameter(
                shape=weight_shape,
                dtype=self.weight_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )

    # weight_scale
    scale_shapes = (
        [local_experts, intermediate * 2],
        [local_experts, hidden],
    )
    for scale_position, scale_shape in enumerate(scale_shapes):
        setattr(
            layer,
            self.added_scale_attrs[scale_position],
            layer.create_parameter(
                shape=scale_shape,
                dtype=self.default_dtype,
                default_initializer=paddle.nn.initializer.Constant(0),
            ),
        )
693
+
694
+ def process_loaded_weights (self , layer : nn .Layer , state_dict ):
695
+ """
696
+ Paddle cutlass load weight process.
697
+ """
536
698
up_gate_proj_weights , down_proj_weights = layer .extract_moe_ffn_weights (state_dict )
537
699
self .check (layer , up_gate_proj_weights , down_proj_weights )
538
-
539
700
for idx , weight_tensor in enumerate ([up_gate_proj_weights , down_proj_weights ]):
540
701
weight_name = self .added_weight_attrs [idx ]
541
702
scale_name = self .added_scale_attrs [idx ]
@@ -547,7 +708,7 @@ def create_weights(self, layer: nn.Layer, state_dict):
547
708
weight_list .append (quant_weight )
548
709
weight_scale_list .append (scale )
549
710
quanted_weight = paddle .stack (weight_list , axis = 0 )
550
- create_and_set_parameter (layer , weight_name , quanted_weight )
711
+ getattr (layer , weight_name ). set_value ( quanted_weight )
551
712
552
713
quanted_weight_scale = paddle .stack (weight_scale_list , axis = 0 )
553
- create_and_set_parameter (layer , scale_name , quanted_weight_scale )
714
+ getattr (layer , scale_name ). set_value ( quanted_weight_scale )
0 commit comments