Commit 57bb6db

Skipping pytree registration in case fsdp is enabled (#40075)
* Skipping pytree registration in case fsdp is enabled
* Beauty changes
* Beauty changes
* Moved the is_fsdp_available function to import utils
* Moved is_fsdp_available to integrations.fsdp
* Skipping pytree registration in case fsdp is enabled
* Beauty changes
* Beauty changes
* Moved the is_fsdp_available function to import utils
* Moved is_fsdp_available to integrations.fsdp
* Added pytree registration inside dynamic cache class
* Making ci/cd lords happy
* Adding a check if DynamicCache is already a leaf
* Adding try/catch for multiple initializations of DynamicCache in test suites
* Moving dynamic cache pytree registration to executorch
* Adding try catch back
1 parent 5b3b7ea commit 57bb6db

File tree

5 files changed: +89 additions, -66 deletions

src/transformers/cache_utils.py

Lines changed: 0 additions & 50 deletions
@@ -4,8 +4,6 @@

 import torch

-from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6
-
 from .configuration_utils import PretrainedConfig
 from .utils import (
     is_hqq_available,
@@ -1072,54 +1070,6 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
         return cache


-# Utilities for `DynamicCache` <> torch.export support
-
-if is_torch_greater_or_equal("2.3"):
-
-    def _get_cache_dict(cache: DynamicCache):
-        if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
-            raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
-
-        if not is_torch_greater_or_equal_than_2_6:
-            logger.warning_once(
-                "DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
-            )
-
-        return {
-            "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
-            "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
-        }
-
-    def _unflatten_dynamic_cache(
-        values,
-        context: torch.utils._pytree.Context,
-    ):
-        dictionary = torch.utils._pytree._dict_unflatten(values, context)
-        cache = DynamicCache()
-        # Reconstruct layers from keys and values lists
-        key_list = dictionary.get("key_cache", [])
-        value_list = dictionary.get("value_cache", [])
-        for idx in range(max(len(key_list), len(value_list))):
-            key = key_list[idx] if idx < len(key_list) else None
-            value = value_list[idx] if idx < len(value_list) else None
-            cache.update(key, value, idx)
-        return cache
-
-    torch.utils._pytree.register_pytree_node(
-        DynamicCache,
-        lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
-        _unflatten_dynamic_cache,
-        serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
-        flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
-            _get_cache_dict(dynamic_cache)
-        ),
-    )
-    # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
-    torch.fx._pytree.register_pytree_flatten_spec(
-        DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
-    )
-
-
 class OffloadedCache(Cache):
     """
     A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
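With this block removed, merely importing `transformers.cache_utils` no longer registers `DynamicCache` as a pytree node; until `register_dynamic_cache_export_support()` (added below in `integrations/executorch.py`) is called, the cache behaves as a single pytree leaf. A minimal sketch of that default behavior, with placeholder tensor shapes:

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Arbitrary (batch, heads, seq_len, head_dim) tensors just to populate one layer.
cache.update(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4), 0)

# With no pytree registration in effect, the whole cache is one leaf.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
print(len(leaves), isinstance(leaves[0], DynamicCache))  # 1 True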

src/transformers/integrations/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -55,7 +55,7 @@
     "eetq": ["replace_with_eetq_linear"],
     "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
     "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
-    "fsdp": ["is_fsdp_managed_module"],
+    "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
     "ggml": [
         "GGUF_CONFIG_MAPPING",
         "GGUF_TOKENIZER_MAPPING",
@@ -204,7 +204,7 @@
     from .eetq import replace_with_eetq_linear
     from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
     from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
-    from .fsdp import is_fsdp_managed_module
+    from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
     from .ggml import (
         GGUF_CONFIG_MAPPING,
         GGUF_TOKENIZER_MAPPING,
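With `is_fsdp_enabled` added to both the import structure and the eager import, the helper is reachable from the public integrations namespace; a minimal import sketch:

from transformers.integrations import is_fsdp_enabled, is_fsdp_managed_module

# Returns False outside an initialized FSDP run (see integrations/fsdp.py below).
print(is_fsdp_enabled())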

src/transformers/integrations/executorch.py

Lines changed: 70 additions & 2 deletions
@@ -15,7 +15,14 @@

 import torch

-from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache
+from ..cache_utils import (
+    DynamicCache,
+    DynamicLayer,
+    DynamicSlidingWindowLayer,
+    EncoderDecoderCache,
+    HybridCache,
+    StaticCache,
+)
 from ..generation.configuration_utils import GenerationConfig
 from ..masking_utils import (
     ALL_MASK_ATTENTION_FUNCTIONS,
@@ -24,7 +31,11 @@
     prepare_padding_mask,
 )
 from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
-from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
+from ..pytorch_utils import (
+    is_torch_greater_or_equal,
+    is_torch_greater_or_equal_than_2_3,
+    is_torch_greater_or_equal_than_2_6,
+)


 # Add this to src/transformers/integrations/executorch.py
@@ -824,6 +835,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
         self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
         self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

+        register_dynamic_cache_export_support()
+
         # Register cache buffers to make them exportable
         for i in range(len(self.static_cache)):
             self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
@@ -996,6 +1009,8 @@ def export_with_dynamic_cache(
     ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
     model.config._attn_implementation = "sdpa_without_vmap"

+    register_dynamic_cache_export_support()
+
     with torch.no_grad():
         exported_program = torch.export.export(
             model,
@@ -1011,6 +1026,59 @@ def export_with_dynamic_cache(
     return exported_program


+def register_dynamic_cache_export_support():
+    """
+    Utilities for `DynamicCache` <> torch.export support
+    """
+
+    try:
+        torch.utils._pytree.register_pytree_node(
+            DynamicCache,
+            lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
+            _unflatten_dynamic_cache,
+            serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
+            flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
+                _get_cache_dict(dynamic_cache)
+            ),
+        )
+        # TODO (tmanlaibaatar) This won't be needed in torch 2.7.
+        torch.fx._pytree.register_pytree_flatten_spec(
+            DynamicCache,
+            lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
+        )
+    # Catching this in case there are multiple runs for some test runs
+    except ValueError as e:
+        if "already registered as pytree node" not in str(e):
+            raise
+
+
+def _get_cache_dict(cache: DynamicCache):
+    """Convert cache to dictionary format for pytree operations."""
+    if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
+        raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")
+
+    if not is_torch_greater_or_equal_than_2_6:
+        logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.")
+
+    return {
+        "key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
+        "value_cache": [layer.values for layer in cache.layers if layer.values is not None],
+    }
+
+
+def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
+    dictionary = torch.utils._pytree._dict_unflatten(values, context)
+    cache = DynamicCache()
+    # Reconstruct layers from keys and values lists
+    key_list = dictionary.get("key_cache", [])
+    value_list = dictionary.get("value_cache", [])
+    for idx in range(max(len(key_list), len(value_list))):
+        key = key_list[idx] if idx < len(key_list) else None
+        value = value_list[idx] if idx < len(value_list) else None
+        cache.update(key, value, idx)
+    return cache
+
+
 def sdpa_mask_without_vmap(
     batch_size: int,
     cache_position: torch.Tensor,
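Once `register_dynamic_cache_export_support()` has run (as both export paths above now ensure), a `DynamicCache` round-trips through torch's pytree machinery like a plain dict of key/value lists. A minimal sketch, with placeholder tensor shapes:

import torch
from transformers.cache_utils import DynamicCache
from transformers.integrations.executorch import register_dynamic_cache_export_support

register_dynamic_cache_export_support()  # safe to call repeatedly: re-registration is swallowed

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 3, 4), torch.zeros(1, 2, 3, 4), 0)

# Flatten to tensor leaves and rebuild an equivalent cache from the spec.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, DynamicCache)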

src/transformers/integrations/fsdp.py

Lines changed: 16 additions & 1 deletion
@@ -13,9 +13,10 @@
 # limitations under the License.
 from __future__ import annotations

+import os
 from typing import TYPE_CHECKING

-from ..utils import is_torch_available
+from ..utils import is_torch_available, strtobool


 if TYPE_CHECKING:
@@ -36,3 +37,17 @@ def is_fsdp_managed_module(module: nn.Module) -> bool:
     return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
         module, "_is_fsdp_managed_module", False
     )
+
+
+def is_fsdp_enabled():
+    if is_torch_available():
+        import torch
+
+        return (
+            torch.distributed.is_available()
+            and torch.distributed.is_initialized()
+            and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
+            and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
+        )
+
+    return False
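The relocated helper gates on both the process-group state and the two Accelerate environment variables it reads via `strtobool`; a minimal sketch of how the flags interact (values here are illustrative, not part of the diff):

import os
from transformers.integrations.fsdp import is_fsdp_enabled

os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"

# Still False here: torch.distributed.is_initialized() must also be True,
# which only happens inside a distributed (e.g. accelerate/torchrun) launch.
print(is_fsdp_enabled())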

src/transformers/modeling_utils.py

Lines changed: 1 addition & 11 deletions
@@ -54,7 +54,7 @@
 from .distributed import DistributedConfig
 from .dynamic_module_utils import custom_object_save
 from .generation import CompileConfig, GenerationConfig
-from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
+from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
 from .integrations.accelerate import find_tied_parameters, init_empty_weights
 from .integrations.deepspeed import _load_state_dict_into_zero3_model
 from .integrations.eager_paged import eager_paged_attention_forward
@@ -124,7 +124,6 @@
     is_torch_xla_available,
     is_torch_xpu_available,
     logging,
-    strtobool,
 )
 from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
 from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
@@ -182,15 +181,6 @@
     from torch.distributed.tensor import DTensor


-def is_fsdp_enabled():
-    return (
-        torch.distributed.is_available()
-        and torch.distributed.is_initialized()
-        and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
-        and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
-    )
-
-
 def is_local_dist_rank_0():
     return (
         torch.distributed.is_available()
