@@ -994,6 +994,8 @@ class DynamicCache(Cache):
         ```
         """
 
+    _export_registered = False
+
     def __init__(
         self,
         ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
@@ -1047,6 +1049,34 @@ def __init__(
         else:
             super().__init__(layers=layers)
 
+        self._register_export_support()
+
+    @classmethod
+    def _register_export_support(cls):
+        """
+        Register `DynamicCache` <> torch.export pytree support (idempotent; runs once per class).
+        """
+        if cls._export_registered:
+            return
+
+        # Pytree registration causes a memory leak for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
+        if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
+            torch.utils._pytree.register_pytree_node(
+                DynamicCache,
+                lambda dynamic_cache: torch.utils._pytree._dict_flatten(cls._get_cache_dict(dynamic_cache)),
+                cls._unflatten_dynamic_cache,
+                serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+                flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
+                    cls._get_cache_dict(dynamic_cache)
+                ),
+            )
+            # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+            torch.fx._pytree.register_pytree_flatten_spec(
+                DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(cls._get_cache_dict(cache), spec)
+            )
+
+        cls._export_registered = True
+
     def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
         """
         Converts the `Cache` instance into its equivalent in the legacy cache format. Used for
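The hunk above replaces import-time registration with an idempotent classmethod: the class-level `_export_registered` flag (added in the first hunk) makes the `__init__`-triggered call a no-op after the first instance. A minimal self-contained sketch of the same pattern; the `ToyCache` class and every name in it are illustrative, not part of transformers:

```python
import torch


class ToyCache:
    _registered = False  # class-level guard, analogous to `_export_registered`

    def __init__(self, tensors=None):
        self.tensors = list(tensors) if tensors is not None else []
        self._register()  # cheap no-op after the first call

    @classmethod
    def _register(cls):
        if cls._registered:
            return  # registering the same type twice raises ValueError
        torch.utils._pytree.register_pytree_node(
            cls,
            lambda c: (c.tensors, None),          # flatten -> (children, context)
            lambda children, ctx: cls(children),  # unflatten -> new instance
        )
        cls._registered = True


cache = ToyCache([torch.zeros(2), torch.ones(2)])
leaves, spec = torch.utils._pytree.tree_flatten(cache)
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, ToyCache) and len(rebuilt.tensors) == 2
```

Note that unflattening constructs a new instance, whose `__init__` calls `_register` again; the guard is what keeps that re-entry safe.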
@@ -1070,12 +1100,9 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
             cache.update(key_states, value_states, layer_idx)
         return cache
 
-
-# Utilities for `DynamicCache` <> torch.export support
-# Pytree registration is not supported for FSDP runs, see here: https://github.com/huggingface/transformers/issues/39795
-if is_torch_greater_or_equal("2.3") and not is_fsdp_enabled():
-
-    def _get_cache_dict(cache: DynamicCache):
+    @staticmethod
+    def _get_cache_dict(cache):
+        """Convert cache to dictionary format for pytree operations."""
         if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
             raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
 
@@ -1089,12 +1116,10 @@ def _get_cache_dict(cache: DynamicCache):
             "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
         }
 
-    def _unflatten_dynamic_cache(
-        values,
-        context: torch.utils._pytree.Context,
-    ):
+    @classmethod
+    def _unflatten_dynamic_cache(cls, values, context: torch.utils._pytree.Context):
         dictionary = torch.utils._pytree._dict_unflatten(values, context)
-        cache = DynamicCache()
+        cache = cls()
         # Reconstruct layers from keys and values lists
         key_list = dictionary.get("key_cache", [])
         value_list = dictionary.get("value_cache", [])
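For reference, the private dict helpers that `_get_cache_dict` and `_unflatten_dynamic_cache` lean on simply split a dict into its values and keys and zip them back together. A quick round trip, assuming the torch 2.3-era `torch.utils._pytree` internals this code already depends on:

```python
import torch

d = {"key_cache": [torch.zeros(1, 2, 4, 8)], "value_cache": [torch.zeros(1, 2, 4, 8)]}
values, context = torch.utils._pytree._dict_flatten(d)  # -> (list of values, list of keys)
rebuilt = torch.utils._pytree._dict_unflatten(values, context)
assert list(rebuilt) == ["key_cache", "value_cache"]  # keys come back in order
```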
@@ -1104,20 +1129,6 @@ def _unflatten_dynamic_cache(
             cache.update(key, value, idx)
         return cache
 
-    torch.utils._pytree.register_pytree_node(
-        DynamicCache,
-        lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
-        _unflatten_dynamic_cache,
-        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-        flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
-            _get_cache_dict(dynamic_cache)
-        ),
-    )
-    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
-    torch.fx._pytree.register_pytree_flatten_spec(
-        DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
-    )
-
 
 class OffloadedCache(Cache):
     """