@@ -818,7 +818,7 @@ def __init__(self, weight, bias, ori_module, w_qdq, a_qdq):
         self.dynamic_quant_weight = False
         self.dynamic_quant_tmp_weight = False
 
-    def forward(self, x):
+    def forward(self, x, dtype=None):
         if hasattr(self, 'buf_rotate') and self.buf_rotate:
             x = self.rotater.rotate(x)
 
@@ -837,10 +837,20 @@ def forward(self, x):
         elif self.dynamic_quant_tmp_weight:
             self.tmp_weight = self.w_qdq(self)
 
+        org_dtype = self.tmp_weight.data.dtype
+        if dtype is not None:
+            self.convert_dtype(dtype)
+
         x = torch.functional.F.linear(x, self.tmp_weight, self.tmp_bias)
 
+        self.convert_dtype(org_dtype)
         return x
 
+    def convert_dtype(self, dtype):
+        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
+        if self.bias is not None:
+            self.bias.data = self.bias.data.to(dtype)
+
     @classmethod
     @torch.no_grad()
     def new(cls, module, w_qdq, a_qdq):
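Not part of the diff: a minimal standalone sketch of the dtype round-trip that the new `dtype` argument and `convert_dtype` helper implement (cast the working weight to the requested dtype for the matmul, then restore the original dtype). The class below is a toy stand-in, not the module being patched here, and all names and shapes are hypothetical.

```python
import torch
import torch.nn.functional as F


class DtypeRoundTripLinear(torch.nn.Module):
    """Toy module illustrating the pattern above; names are hypothetical."""

    def __init__(self, weight, bias=None):
        super().__init__()
        self.tmp_weight = torch.nn.Parameter(weight)
        self.bias = torch.nn.Parameter(bias) if bias is not None else None

    def convert_dtype(self, dtype):
        # Cast the working weight (and bias, if any) in place.
        self.tmp_weight.data = self.tmp_weight.data.to(dtype)
        if self.bias is not None:
            self.bias.data = self.bias.data.to(dtype)

    def forward(self, x, dtype=None):
        org_dtype = self.tmp_weight.data.dtype
        if dtype is not None:
            self.convert_dtype(dtype)      # compute in the caller-requested dtype
        out = F.linear(x, self.tmp_weight, self.bias)
        self.convert_dtype(org_dtype)      # restore the stored dtype afterwards
        return out


layer = DtypeRoundTripLinear(torch.randn(8, 16, dtype=torch.float16))
y = layer(torch.randn(2, 16, dtype=torch.float32), dtype=torch.float32)  # fp32 matmul
```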
@@ -964,21 +974,32 @@ def __init__(self, module):
         # topk selection algorithm
         self.norm_topk_prob = module.config.norm_topk_prob
         self.gating_dim = module.config.hidden_size
-        self.fc = nn.Linear(self.gating_dim, self.n_routed_experts, bias=False)
-        self.fc.weight = module.weight
+        self.fc = getattr(module, 'fc',
+                          nn.Linear(self.gating_dim, self.n_routed_experts, bias=False))
+        if not hasattr(module, 'fc'):
+            self.fc.weight = module.weight
 
     @property
     def weight(self):
         return self.fc.weight
 
+    def _fp32_forward(self, hidden_states):
+        if isinstance(self.fc, tuple(_LLMC_LINEAR_TYPES_)):
+            logits = self.fc(hidden_states.type(torch.float32), dtype=torch.float32)
+        else:
+            org_dtype = self.fc.weight.dtype
+            self.fc.weight.data = self.fc.weight.data.to(torch.float32)
+            logits = self.fc(hidden_states.type(torch.float32))
+            self.fc.weight.data = self.fc.weight.data.to(org_dtype)
+        return logits
+
     def forward(self, hidden_states):
         bsz, seq_len, h = hidden_states.shape
         # compute gating score
         hidden_states = hidden_states.view(-1, h)
-        org_dtype = self.fc.weight.dtype
-        self.fc.weight.data = self.fc.weight.data.to(torch.float32)
-        logits = self.fc(hidden_states.type(torch.float32))
-        self.fc.weight.data = self.fc.weight.data.to(org_dtype)
+
+        logits = self._fp32_forward(hidden_states)
+
         if self.scoring_func == 'softmax':
             scores = logits.softmax(dim=-1, dtype=torch.float32)
         else:
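Also not part of the commit: a hedged sketch of what the extracted `_fp32_forward` does in its non-quantized branch (run the gating matmul in fp32, then restore the router weight's dtype). When `self.fc` is one of the `_LLMC_LINEAR_TYPES_`, the fp32 request is instead passed through the new `dtype` argument shown in the first hunk. The function name and tensor sizes below are made up for illustration.

```python
import torch
import torch.nn as nn


def fp32_gate_logits(fc: nn.Linear, hidden_states: torch.Tensor) -> torch.Tensor:
    # Temporarily promote the router weight to fp32 for the gating matmul,
    # then cast it back so the rest of the model keeps its original dtype.
    org_dtype = fc.weight.dtype
    fc.weight.data = fc.weight.data.to(torch.float32)
    logits = fc(hidden_states.to(torch.float32))
    fc.weight.data = fc.weight.data.to(org_dtype)
    return logits


gate = nn.Linear(16, 4, bias=False).to(torch.bfloat16)    # hypothetical router
h = torch.randn(8, 16, dtype=torch.bfloat16)              # flattened hidden states
scores = fp32_gate_logits(gate, h).softmax(dim=-1, dtype=torch.float32)
```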