Skip to content

Commit 85ae8ba

Browse files
committed
use set
Signed-off-by: Kyle Sayers <[email protected]>
1 parent c45352e commit 85ae8ba

File tree

1 file changed

+4
-8
lines changed
  • src/compressed_tensors/transform/factory

1 file changed

+4
-8
lines changed

src/compressed_tensors/transform/factory/base.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from abc import ABC, abstractmethod
1616
from collections import defaultdict
17-
from typing import List, Optional, Tuple
17+
from typing import List, Optional, Tuple, Set
1818

1919
import torch
2020
import torch.nn.utils.parametrize as P
@@ -164,10 +164,6 @@ def _update_tied_weights(self):
164164
which is used by transformers to detect and remove shared pointers
165165
during saving
166166
"""
167-
# avoid issues with this method being called twice
168-
for transform in self.transforms:
169-
transform._dynamic_tied_weights_keys = list()
170-
171167
# map from data_ptrs to keys
172168
ptr_to_keys: dict[int, List[Tuple[TransformBase, str]]] = defaultdict(list)
173169
for transform in self.transforms:
@@ -184,7 +180,7 @@ def _update_tied_weights(self):
184180
tensor = getattr(shared_keys[0][0], shared_keys[0][1])
185181

186182
for transform, name in shared_keys:
187-
transform._dynamic_tied_weights_keys.append(name)
183+
transform._dynamic_tied_weights_keys.add(name)
188184
setattr(transform, name, tensor)
189185

190186

@@ -195,11 +191,11 @@ class TransformBase(InternalModule, ABC):
195191

196192
args: TransformArgs
197193
weight: Parameter
198-
_dynamic_tied_weights_keys: List[str]
194+
_dynamic_tied_weights_keys: Set[str]
199195

200196
def __init__(self):
201197
super().__init__()
202-
self._dynamic_tied_weights_keys = list()
198+
self._dynamic_tied_weights_keys = set()
203199

204200
@abstractmethod
205201
def forward(self, value: Tensor) -> Tensor:

0 commit comments

Comments
 (0)