Fix RuntimeError when loading quantized models with int8 weights (#39366) #40090
@@ -0,0 +1,69 @@
import unittest
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class TestQuantizedWeightInitialization(unittest.TestCase):
    """Test that quantized weights are not re-initialized during model loading."""

    def test_int8_weights_skipped(self):
        """Test that int8 weights are skipped during initialization."""

        class TestConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.initializer_range = 0.02
Comment on lines +1 to +16

Reviewer: we can just add a simple test in …

@akacmazz: Sure, that makes much more sense. I can move the test to tests/quantization/compressed_tensors_integration/, since that's where compressed-tensors related tests belong. I can add a simple test there that reproduces the original failing scenario with the actual quantized model "nm-testing/tinyllama-w8a8-compressed-hf-quantizer", which is already used in the existing tests. Thank you.

Reviewer: Thanks @akacmazz! Still, I'm a bit confused about why we are trying to initialize weights at all in the case of quantized models; the weights need to be there already, because we are just loading. I will try to take a deeper look, because I think the issue is deeper than this.
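For reference, a minimal sketch of what such an integration test could look like. The test-class and method names here are assumptions; only the model ID comes from the discussion above, and the real suite under tests/quantization/compressed_tensors_integration/ may organize things differently.

import unittest

from transformers import AutoModelForCausalLM


class CompressedTensorsLoadingTest(unittest.TestCase):
    # Hypothetical test class, sketching the suggestion above.
    def test_load_w8a8_model_does_not_raise(self):
        # Reproduces the original failing scenario: loading this quantized
        # checkpoint used to raise a RuntimeError during weight initialization.
        model = AutoModelForCausalLM.from_pretrained(
            "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
        )
        self.assertIsNotNone(model)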
        class TestModel(PreTrainedModel):
            config_class = TestConfig

            def __init__(self, config):
                super().__init__(config)
                self.linear = nn.Linear(10, 10)
                # Simulate quantized weights
                with torch.no_grad():
                    self.linear.weight = nn.Parameter(
                        self.linear.weight.to(torch.int8), requires_grad=False
                    )

        config = TestConfig()
        model = TestModel(config)

        # Store original weight
        original_weight = model.linear.weight.clone()

        # This should not raise an error and should not modify the weight
        model._init_weights(model.linear)

        # Verify weight unchanged and still int8
        self.assertEqual(model.linear.weight.dtype, torch.int8)
        self.assertTrue(torch.equal(model.linear.weight, original_weight))

    def test_float_weights_initialized(self):
        """Test that float weights are still properly initialized."""

        class TestConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.initializer_range = 0.02

        class TestModel(PreTrainedModel):
            config_class = TestConfig

            def __init__(self, config):
                super().__init__(config)
                self.linear = nn.Linear(10, 10)
Comment on lines +46 to +56

Reviewer: Hey, awesome work on this @akacmazz! Small suggestion: perhaps we could move TestConfig and TestModel to the module level as helper classes. This would make the test suite cleaner and prevent the same code from being defined twice.
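A rough sketch of that refactor, for illustration only; the helper names are hypothetical, and the quantization step becomes an option on the shared model so both tests can reuse it.

import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class _HelperConfig(PretrainedConfig):
    # Hypothetical module-level helper, per the suggestion above.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initializer_range = 0.02


class _HelperModel(PreTrainedModel):
    config_class = _HelperConfig

    def __init__(self, config, quantize=False):
        super().__init__(config)
        self.linear = nn.Linear(10, 10)
        if quantize:
            # Simulate quantized weights by casting to int8.
            with torch.no_grad():
                self.linear.weight = nn.Parameter(
                    self.linear.weight.to(torch.int8), requires_grad=False
                )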
        config = TestConfig()
        model = TestModel(config)

        # Store original weight
        original_weight = model.linear.weight.clone()

        # Initialize weights
        model._init_weights(model.linear)

        # Verify weight was modified and remains float32
        self.assertEqual(model.linear.weight.dtype, torch.float32)
        self.assertFalse(torch.equal(model.linear.weight, original_weight))
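For context, the underlying failure mode can be reproduced in isolation; a minimal sketch, assuming stock PyTorch (the exact error text may vary across versions):

import torch

t = torch.zeros(4, dtype=torch.int8)
t.zero_()  # fine: in-place zeroing is implemented for integer dtypes

try:
    t.normal_()  # normal_ has no integer kernel, so this raises
except RuntimeError as e:
    print(f"RuntimeError: {e}")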
@MekkCyber: Do we need this for .zero_() too, or only .normal_()?

@akacmazz: Thanks for the review, @MekkCyber. Yes, good point: we should apply the same dtype check for .zero_() as well. Looking at the code, there are several places where .zero_() is called on weights and biases. The current fix already handles bias at line 2938 with the dtype check, and line 2961 is for normalization layers, which typically don't have quantized biases. However, line 2946, for the embedding padding_idx, could potentially fail with quantized embeddings. Should I update the fix to also check the dtype before calling .zero_() on the padding index, for consistency?
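To make the discussion concrete, here is a minimal sketch of the kind of dtype guard being discussed, written against a simplified stand-in for _init_weights. The function name is hypothetical and this is not the actual transformers implementation (which lives in modeling_utils.py, at the line numbers cited above).

import torch
import torch.nn as nn


def _init_weights_sketch(module, initializer_range=0.02):
    # Simplified illustration of the dtype guard under discussion:
    # quantized (integer) tensors have no normal_ kernel, so only
    # floating-point parameters are (re-)initialized.
    if isinstance(module, nn.Linear):
        if module.weight.dtype.is_floating_point:
            module.weight.data.normal_(mean=0.0, std=initializer_range)
        if module.bias is not None and module.bias.dtype.is_floating_point:
            module.bias.data.zero_()
    elif isinstance(module, nn.Embedding):
        if module.weight.dtype.is_floating_point:
            module.weight.data.normal_(mean=0.0, std=initializer_range)
            if module.padding_idx is not None:
                # Guarding the padding_idx zeroing as well, per the comment above.
                module.weight.data[module.padding_idx].zero_()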