Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion src/transformers/quantizers/quantizer_bnb_4bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
is_accelerate_available,
is_bitsandbytes_available,
is_torch_available,
is_torch_npu_available,
is_torch_xpu_available,
logging,
)
Expand Down Expand Up @@ -171,6 +172,9 @@ def create_quantized_param(

old_value = getattr(module, tensor_name)

# `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
if isinstance(target_device, int) and is_torch_npu_available():
target_device = f"npu:{target_device}"
if tensor_name == "bias":
if param_value is None:
new_value = old_value.to(target_device)
Expand Down Expand Up @@ -259,11 +263,12 @@ def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
torch_dtype = torch.float16
return torch_dtype

# Copied from transformers.quantizers.quantizer_bnb_8bit.Bnb8BitHfQuantizer.update_device_map
def update_device_map(self, device_map):
if device_map is None:
if torch.cuda.is_available():
device_map = {"": torch.cuda.current_device()}
elif is_torch_npu_available():
device_map = {"": f"npu:{torch.npu.current_device()}"}
elif is_torch_xpu_available():
device_map = {"": f"xpu:{torch.xpu.current_device()}"}
else:
Expand Down