18 changes: 13 additions & 5 deletions src/transformers/modeling_utils.py
@@ -2929,13 +2929,21 @@ def _init_weights(self, module):
         std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

         if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.bias is not None:
+            # Skip initialization for quantized weights (int8, uint8)
+            if hasattr(module, "weight") and module.weight.dtype in (torch.int8, torch.uint8):
+                logger.debug(f"Skipping weight initialization for quantized module {module.__class__.__name__} with dtype {module.weight.dtype}")
+            else:
+                module.weight.data.normal_(mean=0.0, std=std)
+            if module.bias is not None and module.bias.dtype not in (torch.int8, torch.uint8):
Contributor:
Do we need this for .zero_() too, or only for .normal_()?

Author:
Thanks for the review. @MekkCyber Yes, we should apply the same dtype check for .zero_() as well; good point. Looking at the code, there are several places where .zero_() is called on weights and biases:

    1. Line 2938: module.bias.data.zero_()
    2. Line 2946: module.weight.data[module.padding_idx].zero_() (for embeddings)
    3. Line 2961: module.bias.data.zero_() (for normalization layers)

The current fix already handles the bias at line 2938 with the dtype check, and line 2961 is for normalization layers, which typically don't have quantized biases. However, line 2946, which zeroes the embedding padding_idx, could potentially fail with quantized embeddings.

Should I update the fix to also check the dtype before calling .zero_() on the padding index, for consistency?

                 module.bias.data.zero_()
         elif isinstance(module, nn.Embedding):
-            module.weight.data.normal_(mean=0.0, std=std)
-            if module.padding_idx is not None:
-                module.weight.data[module.padding_idx].zero_()
+            # Skip initialization for quantized embeddings
+            if hasattr(module, "weight") and module.weight.dtype in (torch.int8, torch.uint8):
+                logger.debug(f"Skipping weight initialization for quantized embedding with dtype {module.weight.dtype}")
+            else:
+                module.weight.data.normal_(mean=0.0, std=std)
+                if module.padding_idx is not None:
+                    module.weight.data[module.padding_idx].zero_()
         elif isinstance(module, nn.MultiheadAttention):
             # This uses torch's original init
             module._reset_parameters()
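For context on why the dtype guard is needed: PyTorch's in-place Tensor.normal_() is only implemented for floating-point dtypes, so running the unpatched _init_weights on an already-quantized int8 parameter raises a RuntimeError during loading. A minimal standalone sketch of the failure mode (assumed behavior, independent of the transformers code above):

import torch
import torch.nn as nn

# Simulate an int8 weight like the ones a compressed-tensors checkpoint provides.
linear = nn.Linear(4, 4)
linear.weight = nn.Parameter(linear.weight.to(torch.int8), requires_grad=False)

try:
    # normal_ has no integer kernels, so this fails on an int8 tensor.
    linear.weight.data.normal_(mean=0.0, std=0.02)
except RuntimeError as err:
    print(f"normal_ on int8 raised: {err}")

# The patched _init_weights checks module.weight.dtype first and skips such modules.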
69 changes: 69 additions & 0 deletions tests/test_quantized_weight_initialization.py
@@ -0,0 +1,69 @@
import unittest
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class TestQuantizedWeightInitialization(unittest.TestCase):
"""Test that quantized weights are not re-initialized during model loading."""

def test_int8_weights_skipped(self):
"""Test that int8 weights are skipped during initialization."""

class TestConfig(PretrainedConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.initializer_range = 0.02
Comment on lines +1 to +16

Contributor:
We can just add a simple test in tests/quantization/compressed_tensors_integration with the failing model instead.

Author:
Sure, that makes much more sense. I can move the test to tests/quantization/compressed_tensors_integration/, since that's where the compressed-tensors related tests belong, and add a simple test there that reproduces the original failing scenario with the actual quantized model "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" that is already used in the existing tests. Thank you.

Contributor:
Thanks @akacmazz! Still, I'm a bit confused about why we are trying to initialize weights at all for quantized models; since we are just loading, the weights should already be there. I'll take a deeper look, because I think the issue is deeper than this.
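A minimal sketch of the integration test proposed in this thread, reproducing the original failing load with the quantized checkpoint mentioned above (the placement in tests/quantization/compressed_tensors_integration/ and the decorator names are assumptions, not the final patch):

import unittest

from transformers import AutoModelForCausalLM
from transformers.testing_utils import require_compressed_tensors, require_torch


@require_torch
@require_compressed_tensors
class CompressedTensorsInitWeightsTest(unittest.TestCase):
    def test_load_w8a8_model_does_not_reinit_quantized_weights(self):
        # Before the fix, loading failed because _init_weights called normal_() on int8 tensors.
        model = AutoModelForCausalLM.from_pretrained(
            "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
        )
        self.assertIsNotNone(model)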


        class TestModel(PreTrainedModel):
            config_class = TestConfig

            def __init__(self, config):
                super().__init__(config)
                self.linear = nn.Linear(10, 10)
                # Simulate quantized weights
                with torch.no_grad():
                    self.linear.weight = nn.Parameter(
                        self.linear.weight.to(torch.int8), requires_grad=False
                    )

        config = TestConfig()
        model = TestModel(config)

        # Store original weight
        original_weight = model.linear.weight.clone()

        # This should not raise an error and should not modify the weight
        model._init_weights(model.linear)

        # Verify weight unchanged and still int8
        self.assertEqual(model.linear.weight.dtype, torch.int8)
        self.assertTrue(torch.equal(model.linear.weight, original_weight))

    def test_float_weights_initialized(self):
        """Test that float weights are still properly initialized."""

        class TestConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.initializer_range = 0.02

        class TestModel(PreTrainedModel):
            config_class = TestConfig

            def __init__(self, config):
                super().__init__(config)
                self.linear = nn.Linear(10, 10)
Comment on lines +46 to +56

Hey, awesome work on this @akacmazz! Small suggestion: perhaps we could move TestConfig and TestModel to the module level as helper classes. This would make the test suite cleaner and prevent the same code from being defined twice.
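A sketch of the suggested refactor, with the helper classes defined once at module level; the quantize flag is introduced here for illustration so both tests can share a single TestModel (an assumed shape, not the final change):

import torch
import torch.nn as nn
from transformers import PretrainedConfig, PreTrainedModel


class TestConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initializer_range = 0.02


class TestModel(PreTrainedModel):
    config_class = TestConfig

    def __init__(self, config, quantize=False):
        super().__init__(config)
        self.linear = nn.Linear(10, 10)
        if quantize:
            # Simulate an already-quantized int8 weight.
            with torch.no_grad():
                self.linear.weight = nn.Parameter(
                    self.linear.weight.to(torch.int8), requires_grad=False
                )


# Each test then builds whichever variant it needs:
#   model = TestModel(TestConfig(), quantize=True)   # int8 case
#   model = TestModel(TestConfig())                  # float case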


        config = TestConfig()
        model = TestModel(config)

        # Store original weight
        original_weight = model.linear.weight.clone()

        # Initialize weights
        model._init_weights(model.linear)

        # Verify weight was modified and remains float32
        self.assertEqual(model.linear.weight.dtype, torch.float32)
        self.assertFalse(torch.equal(model.linear.weight, original_weight))