@@ -2380,21 +2380,33 @@ def _quantize_layer(
                     tmp_attention_mask = [self.attention_mask[i] for i in indices]
                     tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
                     tmp_attention_mask.unsqueeze_(-1)
-                else:
-                    tmp_attention_mask = 1.0
-
-                if self.amp:
-                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    if self.amp:
+                        with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                            output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                            loss = mse_loss(  # pylint: disable=not-callable
+                                (output_q * tmp_attention_mask).to(torch.float32),
+                                (current_output * tmp_attention_mask).to(torch.float32),
+                            )
+                    else:
                         output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
                         loss = mse_loss(  # pylint: disable=not-callable
-                            output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                            (output_q * tmp_attention_mask).to(torch.float32),
+                            (current_output * tmp_attention_mask).to(torch.float32),
                         )
                 else:
-                    output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
-                    loss = mse_loss(  # pylint: disable=not-callable
-                        output_q.to(torch.float32) * tmp_attention_mask,
-                        current_output.to(torch.float32) * tmp_attention_mask,
-                    )
+                    if self.amp:
+                        with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                            output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                            loss = mse_loss(  # pylint: disable=not-callable
+                                output_q.to(torch.float32),
+                                current_output.to(torch.float32),  # multiplying by 1.0 would just copy the output
+                            )
+                    else:
+                        output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                        loss = mse_loss(  # pylint: disable=not-callable
+                            output_q.to(torch.float32), current_output.to(torch.float32)
+                        )
+
                 total_loss += loss.item() / num_elm

                 self._scale_loss_and_backward(scaler, loss)
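For context, this first hunk moves the loss computation inside the attention-mask branches: when a mask is present, the outputs are masked and then cast to float32 before the MSE; when no mask is present, the old multiply-by-1.0 is dropped entirely. Below is a minimal standalone sketch of the masked float32 MSE pattern under autocast; the tensor names, shapes, and autocast dtype are illustrative assumptions, not values from this repository.

import torch
import torch.nn.functional as F

# Illustrative sketch only: shapes, names, and the autocast dtype are assumptions.
device = "cuda" if torch.cuda.is_available() else "cpu"
amp_dtype = torch.float16 if device == "cuda" else torch.bfloat16

output_q = torch.randn(2, 8, 16, device=device)        # stand-in for the quantized layer output
current_output = torch.randn(2, 8, 16, device=device)  # stand-in for the reference output
attention_mask = torch.ones(2, 8, 1, device=device)    # broadcasts over the hidden dimension

with torch.autocast(device_type=device.split(":")[0], dtype=amp_dtype, enabled=(device == "cuda")):
    # Apply the mask first, then cast both operands to float32 so the
    # squared-error reduction runs in full precision even under autocast.
    loss = F.mse_loss(
        (output_q * attention_mask).to(torch.float32),
        (current_output * attention_mask).to(torch.float32),
    )
print(loss.item())

Casting after masking keeps the elementwise multiply in the autocast dtype while the loss reduction itself runs in full precision.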
@@ -2540,18 +2552,29 @@ def _get_loss(
             tmp_attention_mask = [self.attention_mask[i] for i in indices]
             tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
             tmp_attention_mask.unsqueeze_(-1)
-        else:
-            tmp_attention_mask = 1.0
-        if self.amp:
-            with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+            if self.amp:
+                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        (output_q * tmp_attention_mask).to(torch.float32),
+                        (current_output * tmp_attention_mask).to(torch.float32),
+                    )
+            else:
                 loss = mse_loss(  # pylint: disable=not-callable
-                    output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )
+
         else:
-            loss = mse_loss(  # pylint: disable=not-callable
-                output_q.to(torch.float32) * tmp_attention_mask,
-                current_output.to(torch.float32) * tmp_attention_mask,
-            )
+            if self.amp:
+                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        output_q.to(torch.float32), current_output.to(torch.float32)
+                    )
+            else:
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q.to(torch.float32),
+                    current_output.to(torch.float32),
+                )
         return loss

     def _quantize_block(
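After this change, the same four-way branch (mask or no mask, amp or no amp) appears in both _quantize_layer and _get_loss. Purely as a hypothetical sketch, not code from this PR, the shared loss logic could be expressed as a small helper; the name masked_fp32_mse and its signature are invented here for illustration.

import torch
import torch.nn.functional as F


def masked_fp32_mse(output_q, current_output, attention_mask=None):
    """Hypothetical helper (not part of the PR): MSE computed in float32,
    optionally weighted by an attention mask.

    With a mask, both tensors are masked and then cast to float32; without
    one, they are cast directly, avoiding the old multiply-by-1.0 copy.
    """
    if attention_mask is not None:
        output_q = (output_q * attention_mask).to(torch.float32)
        current_output = (current_output * attention_mask).to(torch.float32)
    else:
        output_q = output_q.to(torch.float32)
        current_output = current_output.to(torch.float32)
    return F.mse_loss(output_q, current_output)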