Commit 6a8ac7e

fix accuracy regression (#1041)
1 parent 8d8a1cd commit 6a8ac7e

File tree

1 file changed (+43, -20 lines)

auto_round/compressors/base.py

Lines changed: 43 additions & 20 deletions
@@ -2380,21 +2380,33 @@ def _quantize_layer(
                 tmp_attention_mask = [self.attention_mask[i] for i in indices]
                 tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
                 tmp_attention_mask.unsqueeze_(-1)
-            else:
-                tmp_attention_mask = 1.0
-
-            if self.amp:
-                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                if self.amp:
+                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                        output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                        loss = mse_loss(  # pylint: disable=not-callable
+                            (output_q * tmp_attention_mask).to(torch.float32),
+                            (current_output * tmp_attention_mask).to(torch.float32),
+                        )
+                else:
                     output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
                     loss = mse_loss(  # pylint: disable=not-callable
-                        output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                        (output_q * tmp_attention_mask).to(torch.float32),
+                        (current_output * tmp_attention_mask).to(torch.float32),
                     )
             else:
-                output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
-                loss = mse_loss(  # pylint: disable=not-callable
-                    output_q.to(torch.float32) * tmp_attention_mask,
-                    current_output.to(torch.float32) * tmp_attention_mask,
-                )
+                if self.amp:
+                    with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                        output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                        loss = mse_loss(  # pylint: disable=not-callable
+                            output_q.to(torch.float32),
+                            current_output.to(torch.float32),  # mul 1.0 will copy the output
+                        )
+                else:
+                    output_q = wrapper_linear(current_input)  # pylint: disable=not-callable
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        output_q.to(torch.float32), current_output.to(torch.float32)
+                    )
+
             total_loss += loss.item() / num_elm
 
             self._scale_loss_and_backward(scaler, loss)
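Both branches of _quantize_layer now cast the compared tensors to torch.float32 before mse_loss, and the attention-mask multiply only happens when a mask was actually collected (the old tmp_attention_mask = 1.0 fallback is gone). The snippet below is a standalone sketch, not part of the commit, of the kind of precision gap an explicit float32 cast avoids when the loss operands sit in a low-precision dtype; the tensors and error magnitude are invented for illustration:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
ref = torch.randn(4, 128)                  # stands in for current_output
quant = ref + 1e-3 * torch.randn(4, 128)   # stands in for output_q with a small quantization error

# MSE over bfloat16 copies: a 1e-3 error is near bf16's resolution for values of
# magnitude ~1, so the result is dominated by rounding noise.
loss_bf16 = F.mse_loss(quant.to(torch.bfloat16), ref.to(torch.bfloat16))
# MSE over float32 tensors, as the patched code arranges via .to(torch.float32).
loss_fp32 = F.mse_loss(quant, ref)

print(loss_bf16.item(), loss_fp32.item())  # the two values differ noticeably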
@@ -2540,18 +2552,29 @@ def _get_loss(
             tmp_attention_mask = [self.attention_mask[i] for i in indices]
             tmp_attention_mask = torch.cat(tmp_attention_mask, dim=0).to(device)
             tmp_attention_mask.unsqueeze_(-1)
-        else:
-            tmp_attention_mask = 1.0
-        if self.amp:
-            with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+            if self.amp:
+                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        (output_q * tmp_attention_mask).to(torch.float32),
+                        (current_output * tmp_attention_mask).to(torch.float32),
+                    )
+            else:
                 loss = mse_loss(  # pylint: disable=not-callable
-                    output_q * tmp_attention_mask, current_output * tmp_attention_mask
+                    output_q.to(torch.float32) * tmp_attention_mask,
+                    current_output.to(torch.float32) * tmp_attention_mask,
                 )
+
         else:
-            loss = mse_loss(  # pylint: disable=not-callable
-                output_q.to(torch.float32) * tmp_attention_mask,
-                current_output.to(torch.float32) * tmp_attention_mask,
-            )
+            if self.amp:
+                with autocast(device_type=device.split(":")[0], dtype=self.amp_dtype):
+                    loss = mse_loss(  # pylint: disable=not-callable
+                        output_q.to(torch.float32), current_output.to(torch.float32)
+                    )
+            else:
+                loss = mse_loss(  # pylint: disable=not-callable
+                    output_q.to(torch.float32),
+                    current_output.to(torch.float32),
+                )
         return loss
 
     def _quantize_block(
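The _get_loss hunk applies the same restructuring: instead of falling back to tmp_attention_mask = 1.0 and always multiplying, the code now branches on whether a mask exists and compares the outputs in float32 in every branch. A simplified, hypothetical helper, not part of the commit and with an invented name and signature, that captures the resulting pattern:

from typing import Optional

import torch
from torch.nn.functional import mse_loss


def masked_fp32_mse(
    output_q: torch.Tensor,
    current_output: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # Compare quantized vs. original layer outputs in float32.
    if attention_mask is not None:
        mask = attention_mask.unsqueeze(-1)  # broadcast over the hidden dimension
        return mse_loss(
            (output_q * mask).to(torch.float32),
            (current_output * mask).to(torch.float32),
        )
    # No mask: skip multiplying by 1.0, which would only copy the tensors.
    return mse_loss(output_q.to(torch.float32), current_output.to(torch.float32))

The AMP branches in the patch wrap the same computation in autocast; this sketch folds the masked and unmasked cases into one function purely to show the control-flow change.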
