Commit 49edc46
committed
Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
When using ZeRO-3 with zero_quantized_weights=True and bf16 enabled,
the dequantized weights were incorrectly cast to fp16 instead of
preserving the original bf16 dtype. This caused RuntimeError during
training with BERT and similar models.
The fix adds original_dtype tracking to AllGatherCoalescedHandle,
mirroring the existing pattern in AllGatherHandle, to ensure weights
are converted back to their original dtype after dequantization.
Fixes deepspeedai#7775
Signed-off-by: juyterman1000 <fastrunner10090@gmail.com>1 parent 43125a7 commit 49edc46
1 file changed
+12
-2
lines changed| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
713 | 713 | | |
714 | 714 | | |
715 | 715 | | |
| 716 | + | |
716 | 717 | | |
717 | 718 | | |
718 | 719 | | |
| |||
721 | 722 | | |
722 | 723 | | |
723 | 724 | | |
| 725 | + | |
724 | 726 | | |
725 | 727 | | |
726 | 728 | | |
| |||
735 | 737 | | |
736 | 738 | | |
737 | 739 | | |
738 | | - | |
739 | | - | |
| 740 | + | |
| 741 | + | |
| 742 | + | |
| 743 | + | |
| 744 | + | |
| 745 | + | |
| 746 | + | |
740 | 747 | | |
741 | 748 | | |
742 | 749 | | |
| |||
1469 | 1476 | | |
1470 | 1477 | | |
1471 | 1478 | | |
| 1479 | + | |
| 1480 | + | |
1472 | 1481 | | |
1473 | 1482 | | |
1474 | 1483 | | |
1475 | 1484 | | |
1476 | 1485 | | |
1477 | 1486 | | |
1478 | 1487 | | |
| 1488 | + | |
1479 | 1489 | | |
1480 | 1490 | | |
1481 | 1491 | | |
| |||
0 commit comments