
Commit 8cb81ac

SunMarc and Wauplin authored
Serialization: take into account meta tensor when splitting the state_dict (#2591)
* Enable meta tensor serialization
* getattr is better
* style
* skip meta tensors
* update doc
* Update src/huggingface_hub/serialization/_torch.py

Co-authored-by: Lucain <[email protected]>

* oups

---------

Co-authored-by: Lucain <[email protected]>
1 parent 2c7c19d commit 8cb81ac
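
For context (an explanation inferred from the commit, not text that ships with it): meta tensors live on the "meta" device and carry only shape and dtype metadata, with no memory allocated behind them, so a storage-pointer-based identifier cannot tell two distinct meta parameters apart. A minimal PyTorch sketch:

```python
import torch

# Meta tensors describe shape/dtype only; no memory is allocated, so two
# distinct parameters on the "meta" device cannot be told apart by a
# storage pointer.
a = torch.empty(4, device="meta")
b = torch.empty(4, device="meta")

print(a.device.type)     # "meta"
print(a.shape, b.shape)  # torch.Size([4]) torch.Size([4])
# Depending on the PyTorch version, reading a meta storage's data_ptr()
# either returns 0 or raises; either way it is not usable as a unique id,
# hence the None return introduced in the diff below.
```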

File tree

1 file changed

+7 -4 lines changed


src/huggingface_hub/serialization/_torch.py

Lines changed: 7 additions & 4 deletions
@@ -368,18 +368,21 @@ def _get_unique_id(tensor: "torch.Tensor") -> Union[int, Tuple[Any, ...]]:
     return unique_id
 
 
-def get_torch_storage_id(tensor: "torch.Tensor") -> Tuple["torch.device", Union[int, Tuple[Any, ...]], int]:
+def get_torch_storage_id(tensor: "torch.Tensor") -> Optional[Tuple["torch.device", Union[int, Tuple[Any, ...]], int]]:
     """
     Return unique identifier to a tensor storage.
 
-    Multiple different tensors can share the same underlying storage. For
-    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    Multiple different tensors can share the same underlying storage. This identifier is
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
+    In the case of meta tensors, we return None since we can't tell if they share the same storage.
 
     Taken from https://github.com/huggingface/transformers/blob/1ecf5f7c982d761b4daaa96719d162c324187c64/src/transformers/pytorch_utils.py#L278.
     """
-    return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
+    if tensor.device.type == "meta":
+        return None
+    else:
+        return tensor.device, _get_unique_id(tensor), get_torch_storage_size(tensor)
 
 
 def get_torch_storage_size(tensor: "torch.Tensor") -> int:
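
A minimal usage sketch of the patched helper (the import path is the module touched by this commit; the variable names are illustrative):

```python
import torch
from huggingface_hub.serialization._torch import get_torch_storage_id

base = torch.zeros(4)
view = base[:2]                       # a view sharing `base`'s storage
meta = torch.empty(4, device="meta")

# Tensors backed by the same storage map to the same id, which is what
# lets the state_dict splitter deduplicate shared weights.
assert get_torch_storage_id(base) == get_torch_storage_id(view)

# After this commit, meta tensors return None instead of a spurious
# shared id, and the splitter skips storage-based dedup for them.
assert get_torch_storage_id(meta) is None
```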
