From 18df94fd8e02477c5afc3d35e5e2a413b2edee7e Mon Sep 17 00:00:00 2001
From: akacmazz
Date: Mon, 11 Aug 2025 21:39:19 +0300
Subject: [PATCH] Fix RuntimeError when loading quantized models with int8 weights

Skip weight initialization for int8/uint8 quantized weights in the
_init_weights method. The normal_() function only works with
floating-point tensors, but quantized models contain int8/uint8
weights which should preserve their loaded values.

Fixes #39366

- Add dtype check before calling normal_() on weights
- Skip initialization for int8/uint8 weights and biases
- Add debug logging when skipping quantized weights
- Add comprehensive tests for quantized weight handling
- Maintain backward compatibility with existing models
---
 src/transformers/modeling_utils.py            | 18 +++--
 tests/test_quantized_weight_initialization.py | 69 +++++++++++++++++++
 2 files changed, 82 insertions(+), 5 deletions(-)
 create mode 100644 tests/test_quantized_weight_initialization.py

diff --git a/src/transformers/modeling_utils.py b/src/transformers/modeling_utils.py
index 4d4c3fcbbfd2..97adf2392f09 100644
--- a/src/transformers/modeling_utils.py
+++ b/src/transformers/modeling_utils.py
@@ -2929,13 +2929,21 @@ def _init_weights(self, module):
         std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
 
         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
+            # Skip initialization for quantized weights (int8, uint8)
+            if hasattr(module, "weight") and module.weight.dtype in (torch.int8, torch.uint8):
+                logger.debug(f"Skipping weight initialization for quantized module {module.__class__.__name__} with dtype {module.weight.dtype}")
+            else:
+                module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None and module.bias.dtype not in (torch.int8, torch.uint8):
                 module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+            # Skip initialization for quantized embeddings
+            if hasattr(module, "weight") and module.weight.dtype in (torch.int8, torch.uint8):
+                logger.debug(f"Skipping weight initialization for quantized embedding with dtype {module.weight.dtype}")
+            else:
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.MultiheadAttention):
             # This uses torch's original init
             module._reset_parameters()
diff --git a/tests/test_quantized_weight_initialization.py b/tests/test_quantized_weight_initialization.py
new file mode 100644
index 000000000000..c1b162df1676
--- /dev/null
+++ b/tests/test_quantized_weight_initialization.py
@@ -0,0 +1,69 @@
+import unittest
+import torch
+import torch.nn as nn
+from transformers import PreTrainedModel, PretrainedConfig
+
+
+class TestQuantizedWeightInitialization(unittest.TestCase):
+    """Test that quantized weights are not re-initialized during model loading."""
+
+    def test_int8_weights_skipped(self):
+        """Test that int8 weights are skipped during initialization."""
+
+        class TestConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+                self.initializer_range = 0.02
+
+        class TestModel(PreTrainedModel):
+            config_class = TestConfig
+
+            def __init__(self, config):
+                super().__init__(config)
+                self.linear = nn.Linear(10, 10)
+                # Simulate quantized weights
+                with torch.no_grad():
+                    self.linear.weight = nn.Parameter(
+                        self.linear.weight.to(torch.int8), requires_grad=False
+                    )
+
+        config = TestConfig()
+        model = TestModel(config)
+
+        # Store original weight
+        original_weight = model.linear.weight.clone()
+
+        # This should not raise an error and should not modify the weight
+        model._init_weights(model.linear)
+
+        # Verify weight unchanged and still int8
+        self.assertEqual(model.linear.weight.dtype, torch.int8)
+        self.assertTrue(torch.equal(model.linear.weight, original_weight))
+
+    def test_float_weights_initialized(self):
+        """Test that float weights are still properly initialized."""
+
+        class TestConfig(PretrainedConfig):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+                self.initializer_range = 0.02
+
+        class TestModel(PreTrainedModel):
+            config_class = TestConfig
+
+            def __init__(self, config):
+                super().__init__(config)
+                self.linear = nn.Linear(10, 10)
+
+        config = TestConfig()
+        model = TestModel(config)
+
+        # Store original weight
+        original_weight = model.linear.weight.clone()
+
+        # Initialize weights
+        model._init_weights(model.linear)
+
+        # Verify weight was modified and remains float32
+        self.assertEqual(model.linear.weight.dtype, torch.float32)
+        self.assertFalse(torch.equal(model.linear.weight, original_weight))
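
For reference, a minimal standalone sketch of the failure mode the dtype check above guards against. It assumes only PyTorch; the Linear(4, 4) module, the 0.02 std, and the print messages are illustrative, mirroring the pattern used in the test file rather than any transformers internals:

    import torch
    import torch.nn as nn

    linear = nn.Linear(4, 4)
    # Mimic a quantized checkpoint: store the weight as a non-trainable int8 Parameter.
    linear.weight = nn.Parameter(linear.weight.to(torch.int8), requires_grad=False)

    try:
        # normal_() is only implemented for floating-point tensors, so this raises.
        linear.weight.data.normal_(mean=0.0, std=0.02)
    except RuntimeError as err:
        print(f"unpatched init path fails here: {err}")

    # The guard added in this patch takes this branch instead and leaves the weight untouched.
    if linear.weight.dtype in (torch.int8, torch.uint8):
        print(f"skipping init for {linear.weight.dtype} weight")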
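
The added test module can be run on its own with pytest, e.g.:

    pytest tests/test_quantized_weight_initialization.py -v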