14
14
15
15
import inspect
16
16
from typing import Callable , Optional
17
+ from weakref import ref
17
18
18
19
from compressed_tensors .modeling .kvcache import initialize_hooked_kv_cache
19
20
from compressed_tensors .quantization import (
42
43
43
44
44
45
IMPL_ATTR = "impl"
45
- _original_impl = "eager" # mutable, assumes only one model at a time
46
+ HOOKED_ATTENTION_NAME = "ct_hooked_attention"
46
47
47
48
48
49
class QuantizedAttentionImpl (InternalModule ):
@@ -63,7 +64,7 @@ class QuantizedAttentionImpl(InternalModule):
63
64
64
65
def __init__ (self , attn_module : Module ):
65
66
super ().__init__ ()
66
- self .attn_module_container = [ attn_module ] # avoid circular reference
67
+ self .attn_module = ref ( attn_module ) # avoid circular references
67
68
self ._qparams_initialized = False
68
69
69
70
def forward (
@@ -95,13 +96,14 @@ def forward(
95
96
def initialize_qparams_once (self , model : PreTrainedModel , module : Module ):
96
97
"""
97
98
Initialize attention quantization parameters if they have not already been
98
- intialized . KV cache quantization parameters are initialized by the
99
+ initialized . KV cache quantization parameters are initialized by the
99
100
`QuantizedKVCache`
100
101
101
102
:param model: parent model of attention module
102
103
:param module: attention module to initialize with
103
104
"""
104
- assert module is self .attn_module_container [0 ]
105
+ # TODO: move to initialize.py
106
+ assert module is self .attn_module ()
105
107
scheme : Optional [QuantizationScheme ] = getattr (
106
108
module , "quantization_scheme" , None
107
109
)
@@ -142,13 +144,13 @@ def initialize_hooked_attention(
142
144
"""
143
145
if not hasattr (module , IMPL_ATTR ):
144
146
module .register_module (IMPL_ATTR , QuantizedAttentionImpl (module ))
145
- if model .config ._attn_implementation != "ct_hooked_attention" :
147
+ if model .config ._attn_implementation != HOOKED_ATTENTION_NAME :
146
148
# assumes only one model at a time
147
149
global _original_impl
148
150
_original_impl = model .config ._attn_implementation
149
151
150
- AttentionInterface .register ("ct_hooked_attention" , _ct_hooked_attention )
151
- model .config ._attn_implementation = "ct_hooked_attention"
152
+ AttentionInterface .register (HOOKED_ATTENTION_NAME , _ct_hooked_attention )
153
+ model .config ._attn_implementation = HOOKED_ATTENTION_NAME
152
154
153
155
impl : QuantizedAttentionImpl = getattr (module , IMPL_ATTR )
154
156
if quantize :
0 commit comments