Commit 773de39

use weakref

Signed-off-by: Kyle Sayers <[email protected]>

1 parent 75056bf commit 773de39

File tree

2 files changed: +15 additions, -11 deletions


src/compressed_tensors/modeling/attention.py

Lines changed: 9 additions & 7 deletions
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import Callable, Optional
+from weakref import ref
 
 from compressed_tensors.modeling.kvcache import initialize_hooked_kv_cache
 from compressed_tensors.quantization import (
@@ -42,7 +43,7 @@
 
 
 IMPL_ATTR = "impl"
-_original_impl = "eager"  # mutable, assumes only one model at a time
+HOOKED_ATTENTION_NAME = "ct_hooked_attention"
 
 
 class QuantizedAttentionImpl(InternalModule):
@@ -63,7 +64,7 @@ class QuantizedAttentionImpl(InternalModule):
 
     def __init__(self, attn_module: Module):
         super().__init__()
-        self.attn_module_container = [attn_module]  # avoid circular reference
+        self.attn_module = ref(attn_module)  # avoid circular references
         self._qparams_initialized = False
 
     def forward(
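
A note on the change above: the old list container still held a strong reference from the impl back to its parent attention module (the list only kept nn.Module from registering the parent as a submodule), so the two objects formed a reference cycle. Storing weakref.ref(attn_module) removes the cycle; calling the ref returns the module, or None once it has been collected. A minimal sketch of the pattern, independent of compressed_tensors:

from weakref import ref

import torch


class Impl(torch.nn.Module):
    """Toy stand-in for QuantizedAttentionImpl: a child that must reach its parent."""

    def __init__(self, parent: torch.nn.Module):
        super().__init__()
        self.parent = ref(parent)  # weak back-reference, so no parent/child cycle

    def forward(self, x):
        parent = self.parent()  # dereference; returns None if the parent was collected
        assert parent is not None
        return x


attn = torch.nn.Linear(8, 8)              # toy "attention" layer
attn.register_module("impl", Impl(attn))  # strong reference down, weak reference up
assert attn.impl.parent() is attn         # same identity check as the diff's assert
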
@@ -95,13 +96,14 @@ def forward(
     def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
         """
         Initialize attention quantization parameters if they have not already been
-        intialized. KV cache quantization parameters are initialized by the
+        initialized. KV cache quantization parameters are initialized by the
         `QuantizedKVCache`
 
         :param model: parent model of attention module
         :param module: attention module to initialize with
         """
-        assert module is self.attn_module_container[0]
+        # TODO: move to initialize.py
+        assert module is self.attn_module()
         scheme: Optional[QuantizationScheme] = getattr(
             module, "quantization_scheme", None
         )
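
The _qparams_initialized flag set in __init__ and the method name point to an initialize-once guard around parameter creation. A hedged sketch of that pattern; create_qparams below is a hypothetical stand-in, not the library's API:

import torch


def create_qparams(module: torch.nn.Module) -> None:
    # Hypothetical stand-in for registering quantization scales / zero points.
    module.register_buffer("input_scale", torch.ones(1))


class OnceInitialized(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._qparams_initialized = False

    def initialize_qparams_once(self, module: torch.nn.Module) -> None:
        if self._qparams_initialized:
            return  # repeated calls (e.g. from per-forward hooks) become no-ops
        create_qparams(module)
        self._qparams_initialized = True
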
@@ -142,13 +144,13 @@ def initialize_hooked_attention(
     """
     if not hasattr(module, IMPL_ATTR):
         module.register_module(IMPL_ATTR, QuantizedAttentionImpl(module))
-    if model.config._attn_implementation != "ct_hooked_attention":
+    if model.config._attn_implementation != HOOKED_ATTENTION_NAME:
         # assumes only one model at a time
         global _original_impl
         _original_impl = model.config._attn_implementation
 
-        AttentionInterface.register("ct_hooked_attention", _ct_hooked_attention)
-        model.config._attn_implementation = "ct_hooked_attention"
+        AttentionInterface.register(HOOKED_ATTENTION_NAME, _ct_hooked_attention)
+        model.config._attn_implementation = HOOKED_ATTENTION_NAME
 
     impl: QuantizedAttentionImpl = getattr(module, IMPL_ATTR)
     if quantize:
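
The new HOOKED_ATTENTION_NAME constant replaces the repeated "ct_hooked_attention" string literal used when registering the hooked implementation with transformers' AttentionInterface. A hedged sketch of that registration pattern, assuming a recent transformers release that exposes AttentionInterface and ALL_ATTENTION_FUNCTIONS; my_hooked_attention and MY_HOOKED_NAME are hypothetical stand-ins for _ct_hooked_attention and HOOKED_ATTENTION_NAME:

from typing import Optional, Tuple

import torch
from transformers import AttentionInterface
from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS

MY_HOOKED_NAME = "my_hooked_attention"


def my_hooked_attention(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    **kwargs,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    # Hook point: an impl attached to `module` could observe or quantize
    # query/key/value here, then delegate to a stock attention kernel.
    return ALL_ATTENTION_FUNCTIONS["sdpa"](
        module, query, key, value, attention_mask, **kwargs
    )


AttentionInterface.register(MY_HOOKED_NAME, my_hooked_attention)
# model.config._attn_implementation = MY_HOOKED_NAME  # route attention through the hook
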

src/compressed_tensors/modeling/kvcache.py

Lines changed: 6 additions & 4 deletions
@@ -14,6 +14,7 @@
 
 import inspect
 from typing import Callable, Optional, Tuple
+from weakref import ref
 
 from compressed_tensors.quantization import QuantizationStrategy, forward_quantize
 from compressed_tensors.quantization.lifecycle.initialize import (
@@ -54,7 +55,7 @@ class QuantizedKVCache(InternalModule):
 
     def __init__(self, attn_module: Module):
         super().__init__()
-        self.attn_module_container = [attn_module]  # avoid circular reference
+        self.attn_module = ref(attn_module)  # avoid circular reference
         self.past_key_values: Optional[Cache] = None
         self._qparams_initialized = False
 
@@ -69,7 +70,7 @@ def forward(
         **kwargs,
     ) -> Tuple[Tensor, Tensor]:
         # quantization
-        module = self.attn_module_container[0]
+        module = self.attn_module()
         quant_args_attr = "quantization_scheme.input_activations"
         quant_args = getattr_chain(module, quant_args_attr, None)
         quant_enabled = getattr(module, "quantization_enabled", True)
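
For readers unfamiliar with the helper: getattr_chain(module, "quantization_scheme.input_activations", None) resolves a dotted attribute path with a fallback default. A rough approximation of such a helper for illustration, not the compressed_tensors implementation:

def getattr_chain_sketch(obj, attr_path: str, default=None):
    # Walk "a.b.c" one attribute at a time; fall back to `default`
    # if any link in the chain is missing or None.
    current = obj
    for name in attr_path.split("."):
        if current is None or not hasattr(current, name):
            return default
        current = getattr(current, name)
    return current


# e.g. getattr_chain_sketch(module, "quantization_scheme.input_activations", None)
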
@@ -89,12 +90,13 @@ def forward(
     def initialize_qparams_once(self, model: PreTrainedModel, module: Module):
         """
         Initialize kv cache quantization parameters if they have not already been
-        intialized
+        initialized
 
         :param model: parent model of attention module
         :param module: attention module to initialize with
         """
-        assert module is self.attn_module_container[0]
+        # TODO: move to initialize.py
+        assert module is self.attn_module()
         scheme = getattr(module, "quantization_scheme", None)
         quant_args = getattr(scheme, "input_activations", None)
 