Skip to content

Commit 72da75f

Browse files
committed
Fix bf16 dtype mismatch in AllGatherHandle quantization path and add regression test
- Update AllGatherHandle.wait() to restore original_dtype after dequantization
- Pass original_dtype when instantiating AllGatherHandle for quantized parameters
- Add a regression test for the bf16 + zero_quantized_weights configuration

Fixes the non-coalesced version of issue deepspeedai#7775

Signed-off-by: juyterman1000 <fastrunner10090@gmail.com>
1 parent 49edc46 commit 72da75f

File tree

2 files changed

+65
-5
lines changed

2 files changed

+65
-5
lines changed

deepspeed/runtime/zero/partition_parameters.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -696,8 +696,11 @@ def wait(self, handle_dependency=True) -> None:
696696
self.__original_dtype).to(self.__param.device)
697697
elif self.__quantization:
698698
instrument_w_nvtx(self.__quantization.quant_handle.wait)()
699-
self.__param.data = self.__quantization.backend.dequantize(
700-
self.__quantization.quantized_param, self.__quantization.scale_buffer).to(self.__param.device)
699+
dequantized = self.__quantization.backend.dequantize(self.__quantization.quantized_param,
700+
self.__quantization.scale_buffer)
701+
if self.__original_dtype is not None:
702+
dequantized = dequantized.to(self.__original_dtype)
703+
self.__param.data = dequantized.to(self.__param.device)
701704
self.__param.ds_status = ZeroParamStatus.AVAILABLE
702705

703706

@@ -739,8 +742,8 @@ def wait(self, handle_dependency=True) -> None:
739742
instrument_w_nvtx(self.quantization.quant_handle.wait)()
740743
# Fix for issue #7775: convert dequantized tensor back to original dtype (e.g., bf16)
741744
# to prevent dtype mismatch when zero_quantized_weights is used with bf16
742-
dequantized = self.quantization.backend.dequantize(
743-
self.quantization.quantized_param, self.quantization.scale_buffer)
745+
dequantized = self.quantization.backend.dequantize(self.quantization.quantized_param,
746+
self.quantization.scale_buffer)
744747
if self.original_dtype is not None:
745748
dequantized = dequantized.to(self.original_dtype)
746749
flat_tensor = dequantized.to(self.params[0].device)
@@ -1385,7 +1388,7 @@ def all_gather_coalesced(params: Iterable[Parameter],
13851388
quant_info.backend = self.quantizer_module
13861389
quant_info.quant_handle = quant_handle
13871390
quant_info.scale_buffer = quant_scale_buffer
1388-
return AllGatherHandle(handle, param, quantization=quant_info)
1391+
return AllGatherHandle(handle, param, quantization=quant_info, original_dtype=original_dtype)
13891392

13901393
else:
13911394
if self.use_all_reduce_for_fetch_params and not quantize and not use_secondary_tensor:
Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import pytest
2+
import torch
3+
import deepspeed
4+
from unit.common import DistributedTest
5+
from unit.simple_model import SimpleModel, random_dataloader
6+
7+
8+
class TestZeroQuantBF16(DistributedTest):
    """Regression test for bf16 dtype preservation with ZeRO-3 quantized weights.

    Exercises the all-gather + dequantization path with ``zero_quantized_weights``
    enabled under bf16 training and asserts that parameter data keeps the
    bfloat16 dtype afterwards (regression for deepspeedai#7775, where the
    dequantized tensor came back in the backend's output dtype).
    """
    world_size = 2

    @pytest.mark.parametrize("zero_quantized_weights", [True])
    def test_bf16_quantized_weights(self, zero_quantized_weights):
        if not deepspeed.get_accelerator().is_bf16_supported():
            pytest.skip("bf16 is not supported by this accelerator")

        config_dict = {
            "train_micro_batch_size_per_gpu": 1,
            "zero_optimization": {
                "stage": 3,
                "zero_quantized_weights": zero_quantized_weights,
            },
            "bf16": {
                "enabled": True
            },
            "optimizer": {
                "type": "Adam",
                "params": {
                    "lr": 1e-3
                }
            }
        }

        hidden_dim = 128
        model = SimpleModel(hidden_dim=hidden_dim)
        # Pass model_parameters explicitly so the engine can construct the
        # config-declared Adam optimizer from them.
        model, _, _, _ = deepspeed.initialize(model=model,
                                              model_parameters=model.parameters(),
                                              config=config_dict)

        # The engine should have cast all parameters to bf16.
        for param in model.parameters():
            assert param.dtype == torch.bfloat16

        data_loader = random_dataloader(model=model,
                                        total_samples=2,
                                        hidden_dim=hidden_dim,
                                        device=model.device,
                                        dtype=torch.bfloat16)

        # One step is enough to trigger all_gather + dequantization of the
        # ZeRO-3 partitioned parameters.
        for batch in data_loader:
            loss = model(batch[0], batch[1])

            # After all_gather, param.data must still be bfloat16; without the
            # fix the dequantize backend's output dtype leaks through.
            for name, param in model.named_parameters():
                assert param.data.dtype == torch.bfloat16, f"Parameter {name} data dtype is {param.data.dtype}, expected torch.bfloat16"

            model.backward(loss)
            model.step()
            break

0 commit comments

Comments
 (0)