Skip to content

Commit 2ef1ab2

Browse files
committed
simplify function
Signed-off-by: Kyle Sayers <[email protected]>
1 parent 33b71b3 commit 2ef1ab2

File tree

1 file changed

+8
-14
lines changed
  • src/compressed_tensors/transform

1 file changed

+8
-14
lines changed

src/compressed_tensors/transform/apply.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
from collections import defaultdict
16-
from typing import List, Tuple
16+
from typing import Dict, List, Tuple
1717

1818
import torch
1919
from accelerate.utils import has_offloaded_params
@@ -59,20 +59,14 @@ def _tie_offloaded_tensors(model: torch.nn.Module):
5959
:param model: model potentially containing offloaded meta tensors to fix
6060
"""
6161

62-
# map from offloaded tensor pointers to module-key locations
63-
offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
62+
# ensure that if a location shares an offloaded tensor pointers, that the
63+
# meta tensor is also identical (assigned to the first instance of parameter)
64+
ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
6465
for module in model.modules():
6566
if has_offloaded_params(module):
6667
for key, _ in module.named_parameters(recurse=False):
67-
param = module._hf_hook.weights_map[key]
68-
offloaded_ptrs[id(param)].append((module, key))
68+
offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()
6969

70-
# ensure that if a location shares an offloaded tensor pointers, that the
71-
# meta tensor is also identical (assigned to the first element of the set)
72-
for shared_keys in offloaded_ptrs.values():
73-
assert len(shared_keys) >= 1
74-
first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
75-
assert first_tensor.device.type == "meta"
76-
for module, key in shared_keys:
77-
assert getattr(module, key).device.type == "meta"
78-
setattr(module, key, first_tensor)
70+
if offloaded_ptr not in ptr_to_meta:
71+
ptr_to_meta[offloaded_ptr] = getattr(module, key)
72+
setattr(module, key, ptr_to_meta[offloaded_ptr])

0 commit comments

Comments
 (0)