Commit f4213d4
Added pytree registration inside dynamic cache class
1 parent 5c3b018 commit f4213d4

1 file changed: +36 −25 lines

src/transformers/cache_utils.py

Lines changed: 36 additions & 25 deletions
@@ -994,6 +994,8 @@ class DynamicCache(Cache):
         ```
         """
 
+    _export_registered = False
+
     def __init__(
         self,
         ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1047,6 +1049,34 @@ def __init__(
         else:
             super().__init__(layers=layers)
 
+        self._register_export_support()
+
+    @classmethod
+    def _register_export_support(cls):
+        """
+        Utilities for `DynamicCache` <> torch.export support
+        """
+        if cls._export_registered:
+            return
+
+        # Pytree registration causes a memory leak for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
+        if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
+            torch.utils._pytree.register_pytree_node(
+                DynamicCache,
+                lambda dynamic_cache: torch.utils._pytree._dict_flatten(cls._get_cache_dict(dynamic_cache)),
+                cls._unflatten_dynamic_cache,
+                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+                flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
+                    cls._get_cache_dict(dynamic_cache)
+                ),
+            )
+            # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(cls._get_cache_dict(cache), spec)
+            )
+
+        cls._export_registered = True
+
     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
@@ -1070,12 +1100,9 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tens
         cache.update(key_states, value_states, layer_idx)
         return cache
 
-
-# Utilities for `DynamicCache` <> torch.export support
-# Pytree registration is not supported for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
-if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
-
-    def _get_cache_dict(cache: DynamicCache):
+    @staticmethod
+    def _get_cache_dict(cache):
+        """Convert cache to dictionary format for pytree operations."""
         if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
             raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
 
@@ -1089,12 +1116,10 @@ def _get_cache_dict(cache: DynamicCache):
             "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
         }
 
-    def _unflatten_dynamic_cache(
-        values,
-        context: torch.utils._pytree.Context,
-    ):
+    @classmethod
+    def _unflatten_dynamic_cache(cls, values, context: torch.utils._pytree.Context):
         dictionary = torch.utils._pytree._dict_unflatten(values, context)
-        cache = DynamicCache()
+        cache = cls()
         # Reconstruct layers from keys and values lists
         key_list = dictionary.get("key_cache", [])
         value_list = dictionary.get("value_cache", [])
@@ -1104,20 +1129,6 @@ def _unflatten_dynamic_cache(
             cache.update(key, value, idx)
         return cache
 
-    torch.utils._pytree.register_pytree_node(
-        DynamicCache,
-        lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
-        _unflatten_dynamic_cache,
-        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-        flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
-            _get_cache_dict(dynamic_cache)
-        ),
-    )
-    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
-    torch.fx._pytree.register_pytree_flatten_spec(
-        DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
-    )
-
 
 class OffloadedCache(Cache):
     """

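The structural change is that registration used to run as a module-level side effect at import time; it now runs on first `DynamicCache` construction, guarded by the `_export_registered` class flag so it happens at most once per process. A generic sketch of that run-once pattern (hypothetical `LazyRegistrar`, not transformers code):

class LazyRegistrar:
    _registered = False

    def __init__(self):
        self._register_once()

    @classmethod
    def _register_once(cls):
        # The class-level flag makes repeated construction a no-op; the
        # one-time side effect (e.g. a pytree registration) runs on the
        # first instantiation only.
        if cls._registered:
            return
        print("one-time setup")
        cls._registered = True


LazyRegistrar()  # prints "one-time setup"
LazyRegistrar()  # silent: the flag short-circuits re-registration

Presumably this also means `is_fsdp_enabled()` is evaluated when a cache is actually built, rather than at import time when the distributed setup may not be configured yet.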