Skip to content

Commit 49edc46

Browse files
committed
Fix bf16 dtype mismatch in ZeRO-3 with zero_quantized_weights
When using ZeRO-3 with zero_quantized_weights=True and bf16 enabled, the dequantized weights were incorrectly cast to fp16 instead of preserving the original bf16 dtype. This caused a RuntimeError during training with BERT and similar models. The fix adds original_dtype tracking to AllGatherCoalescedHandle, mirroring the existing pattern in AllGatherHandle, so that weights are converted back to their original dtype after dequantization. Fixes deepspeedai#7775. Signed-off-by: juyterman1000 <fastrunner10090@gmail.com>
1 parent 43125a7 commit 49edc46

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,6 +713,7 @@ def __init__(
713713
world_size: int,
714714
use_secondary_tensor=False,
715715
quantization=None,
716+
original_dtype=None,
716717
) -> None:
717718
self.allgather_handle = allgather_handle
718719
self.params = params
@@ -721,6 +722,7 @@ def __init__(
721722
self.use_secondary_tensor = use_secondary_tensor
722723
self.complete = False
723724
self.quantization = quantization
725+
self.original_dtype = original_dtype
724726

725727
for param in self.params:
726728
if param.ds_status != ZeroParamStatus.INFLIGHT:
@@ -735,8 +737,13 @@ def wait(self, handle_dependency=True) -> None:
735737

736738
if self.quantization:
737739
instrument_w_nvtx(self.quantization.quant_handle.wait)()
738-
flat_tensor = self.quantization.backend.dequantize(
739-
self.quantization.quantized_param, self.quantization.scale_buffer).to(self.params[0].device)
740+
# Fix for issue #7775: convert dequantized tensor back to original dtype (e.g., bf16)
741+
# to prevent dtype mismatch when zero_quantized_weights is used with bf16
742+
dequantized = self.quantization.backend.dequantize(
743+
self.quantization.quantized_param, self.quantization.scale_buffer)
744+
if self.original_dtype is not None:
745+
dequantized = dequantized.to(self.original_dtype)
746+
flat_tensor = dequantized.to(self.params[0].device)
740747

741748
self.partitions: List[Parameter] = []
742749
for i in range(self.world_size):
@@ -1469,13 +1476,16 @@ def all_gather_coalesced(params: Iterable[Parameter],
14691476
quant_info.scale_buffer = quant_scale_buffer
14701477
quant_info.partition_sz = partition_sz
14711478
quant_info.world_size = world_size
1479+
# Get the original dtype from param's ds_tensor for proper dtype restoration after dequantization
1480+
original_dtype = params[0].ds_tensor.dtype if params else None
14721481
return AllGatherCoalescedHandle(
14731482
allgather_handle=handle,
14741483
params=params,
14751484
partitions=None,
14761485
world_size=world_size,
14771486
use_secondary_tensor=use_secondary_tensor,
14781487
quantization=quant_info,
1488+
original_dtype=original_dtype,
14791489
)
14801490

14811491
def partition(param_list=None, hierarchy=0, has_been_updated=False, free_data=True):

0 commit comments

Comments
 (0)