Fix RuntimeError when loading quantized models with int8 weights (#39366) #40090
Conversation
Skip weight initialization for int8/uint8 quantized weights in the _init_weights method. The normal_() function only works with floating-point tensors, but quantized models contain int8/uint8 weights which should preserve their loaded values (illustrated in the snippet after this list). Fixes huggingface#39366
- Add dtype check before calling normal_() on weights
- Skip initialization for int8/uint8 weights and biases
- Add debug logging when skipping quantized weights
- Add comprehensive tests for quantized weight handling
- Maintain backward compatibility with existing models
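For reference, a tiny standalone snippet (purely illustrative, not part of the PR) showing the PyTorch behavior the description refers to: calling the in-place normal_() on an integer tensor raises a RuntimeError.
import torch

w = torch.zeros(4, 4, dtype=torch.int8)
try:
    w.normal_(mean=0.0, std=0.02)  # normal_() expects a floating-point tensor
except RuntimeError as err:
    print(f"RuntimeError: {err}")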
cc @MekkCyber for quantization
I am also encountering this issue :)
Thanks for fixing this 🤗! Left some comments below.
logger.debug(f"Skipping weight initialization for quantized module {module.__class__.__name__} with dtype {module.weight.dtype}") | ||
else: | ||
module.weight.data.normal_(mean=0.0, std=std) | ||
if module.bias is not None and module.bias.dtype not in (torch.int8, torch.uint8): |
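For context, a self-contained sketch of the guard this diff adds, assuming a Linear-style module handled inside an _init_weights-like method (the helper name is hypothetical and the surrounding transformers code is not reproduced here):
import logging

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


def init_linear_sketch(module: nn.Linear, std: float = 0.02) -> None:
    # Sketch only: normal_() is defined for floating-point tensors, so skip
    # re-initialization when the weights are already int8/uint8 quantized.
    if module.weight.dtype in (torch.int8, torch.uint8):
        logger.debug(
            "Skipping weight initialization for quantized module %s with dtype %s",
            module.__class__.__name__,
            module.weight.dtype,
        )
    else:
        module.weight.data.normal_(mean=0.0, std=std)
    if module.bias is not None and module.bias.dtype not in (torch.int8, torch.uint8):
        module.bias.data.zero_()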
do we need this for .zero_() too, or only .normal_()?
Thanks for the review, @MekkCyber. Yes, good point, we should apply the same dtype check for .zero_() as well. Looking at the code, there are several places where .zero_() is called on weights and biases:
- Line 2938: module.bias.data.zero_()
- Line 2946: module.weight.data[module.padding_idx].zero_() (for embeddings)
- Line 2961: module.bias.data.zero_() (for normalization layers)
The current fix already handles the bias at line 2938 with the dtype check, and line 2961 is for normalization layers, which typically don't have quantized biases. However, line 2946 for the embedding padding_idx could potentially fail with quantized embeddings.
Should I update the fix to also check the dtype before calling .zero_() on the padding index, for consistency? Something along the lines of the sketch below.
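A minimal sketch of what that padding-index guard could look like (hypothetical helper name; the actual embedding-init code in transformers is not reproduced here):
import torch
import torch.nn as nn


def init_embedding_sketch(module: nn.Embedding, std: float = 0.02) -> None:
    # Sketch only: guard both normal_() and the padding-index zero_() so
    # int8/uint8 quantized embeddings keep their loaded values.
    if module.weight.dtype in (torch.int8, torch.uint8):
        return
    module.weight.data.normal_(mean=0.0, std=std)
    if module.padding_idx is not None:
        module.weight.data[module.padding_idx].zero_()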
import unittest
import torch
import torch.nn as nn
from transformers import PreTrainedModel, PretrainedConfig


class TestQuantizedWeightInitialization(unittest.TestCase):
    """Test that quantized weights are not re-initialized during model loading."""

    def test_int8_weights_skipped(self):
        """Test that int8 weights are skipped during initialization."""

        class TestConfig(PretrainedConfig):
            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.initializer_range = 0.02
we can just add a simple test in tests/quantization/compressed_tensors_integration with the failing model instead
Sure, that makes much more sense. I can move the test to tests/quantization/compressed_tensors_integration/ since that's where compressed-tensors related tests belong. I can add a simple test there that reproduces the original failing scenario with the actual quantized model "nm-testing/tinyllama-w8a8-compressed-hf-quantizer" that's already being used in the existing tests. Thank you.
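A minimal sketch of what such an integration test could look like, assuming the nm-testing/tinyllama-w8a8-compressed-hf-quantizer checkpoint already used by the existing compressed-tensors tests (test and class names are illustrative, and compressed-tensors must be installed for the load to succeed):
import unittest

from transformers import AutoModelForCausalLM


class CompressedTensorsInt8LoadingTest(unittest.TestCase):
    def test_load_w8a8_model_does_not_raise(self):
        # Regression check for huggingface#39366: loading must not call
        # normal_() on int8 weights and fail with a RuntimeError.
        model = AutoModelForCausalLM.from_pretrained(
            "nm-testing/tinyllama-w8a8-compressed-hf-quantizer"
        )
        self.assertIsNotNone(model)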
Thanks @akacmazz! Still, I'm a bit confused why we are trying to initialize weights in the case of quantized models; the weights need to be there already because we are just loading. Will try to take a deeper look, because I think the issue is deeper than this.
BTW, I also encountered a scenario where I was loading an llm-compressor compressed model and it did not have the weights in int8, but rather it just did not have any weight attribute at all. This is because the weight was still packed and was called
With this, my model skipped initialization for these layers and ran fine. I can give more details if you want.
class TestConfig(PretrainedConfig):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initializer_range = 0.02


class TestModel(PreTrainedModel):
    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(10, 10)
Hey, awesome work on this @akacmazz! Small suggestion, perhaps we could move TestConfig and TestModel to the module level as helper classes. This would make the test suite cleaner and prevent the same code from being defined twice.
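A minimal sketch of that refactor, lifting the two helpers out of the test method so every test in the module can share them (module layout assumed, not the final file):
import torch.nn as nn

from transformers import PreTrainedModel, PretrainedConfig


class TestConfig(PretrainedConfig):
    # Shared dummy config exposing only the attribute the init code reads.
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self.initializer_range = 0.02


class TestModel(PreTrainedModel):
    # Shared dummy model with a single Linear layer to exercise initialization.
    config_class = TestConfig

    def __init__(self, config):
        super().__init__(config)
        self.linear = nn.Linear(10, 10)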
What does this PR do?
Fixes a RuntimeError that occurs when loading llm-compressor W8A8 quantized models. The issue happens because the _init_weights method attempts to apply a normal_() distribution to int8 tensors, which PyTorch doesn't support.
Before & After
Before (❌)