@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from collections import defaultdict
-from typing import List, Tuple
+from typing import Dict, List, Tuple
 
 import torch
 from accelerate.utils import has_offloaded_params
@@ -59,20 +59,14 @@ def _tie_offloaded_tensors(model: torch.nn.Module):
     :param model: model potentially containing offloaded meta tensors to fix
     """
 
-    # map from offloaded tensor pointers to module-key locations
-    offloaded_ptrs: dict[int, List[Tuple[torch.nn.Module, str]]] = defaultdict(list)
+    # ensure that locations which share an offloaded tensor pointer also
+    # share an identical meta tensor (assigned to the first instance of the parameter)
+    ptr_to_meta: Dict[int, torch.nn.Parameter] = dict()
     for module in model.modules():
         if has_offloaded_params(module):
             for key, _ in module.named_parameters(recurse=False):
-                param = module._hf_hook.weights_map[key]
-                offloaded_ptrs[id(param)].append((module, key))
+                offloaded_ptr = module._hf_hook.weights_map[key].data_ptr()
 
-    # ensure that if a location shares an offloaded tensor pointers, that the
-    # meta tensor is also identical (assigned to the first element of the set)
-    for shared_keys in offloaded_ptrs.values():
-        assert len(shared_keys) >= 1
-        first_tensor = getattr(shared_keys[0][0], shared_keys[0][1])
-        assert first_tensor.device.type == "meta"
-        for module, key in shared_keys:
-            assert getattr(module, key).device.type == "meta"
-            setattr(module, key, first_tensor)
+                if offloaded_ptr not in ptr_to_meta:
+                    ptr_to_meta[offloaded_ptr] = getattr(module, key)
+                setattr(module, key, ptr_to_meta[offloaded_ptr])
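The rewritten loop keys the shared placeholder on the offloaded tensor's data_ptr() instead of grouping locations by the Python id() of the value returned by weights_map: the first module/key that reads a given offloaded storage keeps its meta parameter, and every later location pointing at the same storage is assigned that same parameter object. Below is a minimal, self-contained sketch of that deduplication pattern using plain torch and made-up tensor names (no accelerate hooks); it illustrates the behavior rather than reproducing the patched function:

import torch

# Two "offloaded" CPU tensors that share storage, as tied weights would
# (e.g. an embedding table and an lm_head pointing at the same data).
shared = torch.zeros(4, 8)
offloaded = {"embed.weight": shared, "lm_head.weight": shared}

# Each location initially holds its own placeholder meta parameter.
metas = {
    name: torch.nn.Parameter(torch.empty(4, 8, device="meta"))
    for name in offloaded
}

# Deduplicate by the offloaded tensor's data pointer: the first location's
# meta parameter is reused for every later location with the same storage.
ptr_to_meta = {}
for name, tensor in offloaded.items():
    ptr = tensor.data_ptr()
    if ptr not in ptr_to_meta:
        ptr_to_meta[ptr] = metas[name]
    metas[name] = ptr_to_meta[ptr]

assert metas["embed.weight"] is metas["lm_head.weight"]

One practical property of keying on data_ptr() rather than id() is that two distinct tensor objects viewing the same underlying storage are still recognized as tied, which object-identity grouping would miss.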