
Commit 2a8f6b7

pass config for init later
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 2279181

3 files changed (+9 −9 lines)


src/compressed_tensors/modeling/attention.py (4 additions, 3 deletions)

@@ -31,7 +31,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
-from transformers import AttentionInterface, PreTrainedModel
+from transformers import AttentionInterface, PretrainedConfig, PreTrainedModel
 from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS


@@ -62,8 +62,9 @@ class QuantizedAttentionImpl(InternalModule):
     :param attn_module: parent attention module
     """

-    def __init__(self, attn_module: Module):
+    def __init__(self, config: PretrainedConfig, attn_module: Module):
         super().__init__()
+        self.config = config
         self.attn_module = ref(attn_module)  # avoid circular references
         self._qparams_initialized = False

@@ -143,7 +144,7 @@ def initialize_hooked_attention(
     :param quantize: initialize attention quantization parameters
     """
     if not hasattr(module, IMPL_ATTR):
-        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
+        module.register_module(IMPL_ATTR, QuantizedAttentionImpl(model.config, module))
     if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
         # assumes only one model at a time
         global _original_impl
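
With this change, QuantizedAttentionImpl carries the model config from construction, so attention hyperparameters are already on hand when quantization parameters are initialized later. The sketch below is not part of the commit; it only illustrates the new (config, attn_module) constructor shown in the diff, and the checkpoint name and attention-module path are arbitrary examples.

# Minimal sketch (assumption, not from this commit): constructing the impl
# with the new (config, attn_module) signature shown in the diff above.
# The checkpoint name and attention-module path are arbitrary illustrations.
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.attention import QuantizedAttentionImpl

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
attn_module = model.model.decoder.layers[0].self_attn

# The impl now stores the model config, making attention hyperparameters
# available when qparams are initialized later.
impl = QuantizedAttentionImpl(model.config, attn_module)
assert impl.config is model.config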

src/compressed_tensors/modeling/kvcache.py (4 additions, 3 deletions)

@@ -25,7 +25,7 @@
 from torch import Tensor
 from torch.nn import Module
 from torch.utils.hooks import RemovableHandle
-from transformers import Cache, PreTrainedModel
+from transformers import Cache, PretrainedConfig, PreTrainedModel


 __all__ = [
@@ -53,8 +53,9 @@ class QuantizedKVCache(InternalModule):
     :param attn_module: parent attention module
     """

-    def __init__(self, attn_module: Module):
+    def __init__(self, config: PretrainedConfig, attn_module: Module):
         super().__init__()
+        self.config = config
         self.attn_module = ref(attn_module)  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
         self._qparams_initialized = False
@@ -134,7 +135,7 @@ def initialize_hooked_kv_cache(
     :param quantize: initialize kv cache quantization parameters
     """
     if not hasattr(module, KV_CACHE_ATTR):
-        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(module))
+        module.register_module(KV_CACHE_ATTR, QuantizedKVCache(model.config, module))
         module.register_forward_pre_hook(_kv_cache_attention_hook, with_kwargs=True)

     kv_cache: QuantizedKVCache = getattr(module, KV_CACHE_ATTR)
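
The kvcache.py change mirrors the attention one: the cache wrapper also receives the config at construction time. Again, the sketch below is only an illustration of the constructor shown in the diff, with an arbitrary checkpoint and module path.

# Minimal sketch (assumption, not from this commit): the KV cache wrapper
# follows the same pattern and now also receives the config at construction.
from transformers import AutoModelForCausalLM

from compressed_tensors.modeling.kvcache import QuantizedKVCache

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
attn_module = model.model.decoder.layers[0].self_attn

kv_cache = QuantizedKVCache(model.config, attn_module)
assert kv_cache.config is model.config
assert kv_cache.past_key_values is None  # starts as None per __init__ above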

src/compressed_tensors/transform/factory/base.py (1 addition, 3 deletions)

@@ -108,9 +108,7 @@ def apply_to_model(self, model: Module, use_tqdm=True):
         for module, arg in tqdm.tqdm(modules_args, desc=desc, disable=(not use_tqdm)):
             self._apply_to_module(model, module, arg)

-    def _apply_to_module(
-        self, model: Module, module: Module, args: TransformArgs
-    ):
+    def _apply_to_module(self, model: Module, module: Module, args: TransformArgs):
         """
         Create transforms and apply them to the module
