Skip to content
This repository was archived by the owner on Jun 3, 2025. It is now read-only.

Commit 4d44da6

Browse files
authored
Set embedding weight QAT params to the correct device in DataParallel (DP) mode (#1013) (#1014)
1 parent ea432e9 commit 4d44da6

File tree

1 file changed

+8
-1
lines changed
  • src/sparseml/pytorch/sparsification/quantization

1 file changed

+8
-1
lines changed

src/sparseml/pytorch/sparsification/quantization/helpers.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -796,9 +796,15 @@ def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConf
796796
embedding.weight_fake_quant = qconfig.weight()
797797

798798
def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:
799+
weight = self.weight_fake_quant(self.weight)
800+
if weight.device != input.device:
801+
# torch DataParallel may not pick up the overwritten bound method
802+
# send the weight to the correct device
803+
weight = weight.to(input.device)
804+
799805
return torch.nn.functional.embedding(
800806
input,
801-
self.weight_fake_quant(self.weight),
807+
weight,
802808
self.padding_idx,
803809
self.max_norm,
804810
self.norm_type,
@@ -808,6 +814,7 @@ def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:
808814

809815
# bind qat forward to embedding
810816
qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__)
817+
embedding.to(embedding.weight.device)  # move weight_fake_quant to the correct device
811818
setattr(embedding, "forward", qat_forward_bound)
812819

813820

0 commit comments

Comments
 (0)