50 changes: 0 additions & 50 deletions src/transformers/cache_utils.py
@@ -4,8 +4,6 @@

import torch

from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_6

from .configuration_utils import PretrainedConfig
from .utils import (
is_hqq_available,
@@ -1072,54 +1070,6 @@ def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tens
return cache


# Utilities for `DynamicCache` <> torch.export support

if is_torch_greater_or_equal("2.3"):

def _get_cache_dict(cache: DynamicCache):
if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

if not is_torch_greater_or_equal_than_2_6:
logger.warning_once(
"DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions."
)

return {
"key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
"value_cache": [layer.values for layer in cache.layers if layer.values is not None],
}

def _unflatten_dynamic_cache(
values,
context: torch.utils._pytree.Context,
):
dictionary = torch.utils._pytree._dict_unflatten(values, context)
cache = DynamicCache()
# Reconstruct layers from keys and values lists
key_list = dictionary.get("key_cache", [])
value_list = dictionary.get("value_cache", [])
for idx in range(max(len(key_list), len(value_list))):
key = key_list[idx] if idx < len(key_list) else None
value = value_list[idx] if idx < len(value_list) else None
cache.update(key, value, idx)
return cache

torch.utils._pytree.register_pytree_node(
DynamicCache,
lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
_unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
_get_cache_dict(dynamic_cache)
),
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(
DynamicCache, lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec)
)


class OffloadedCache(Cache):
"""
A drop-in replacement for DynamicCache that conserves accelerator (GPU, XPU) memory at the expense of more CPU memory.
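The block deleted above used to register DynamicCache as a pytree node at import time; after this change the registration only happens when the ExecuTorch integration asks for it (see register_dynamic_cache_export_support in the executorch.py diff below). A minimal sketch of what that means in practice, assuming nothing else in the process has registered the node (shapes are illustrative only):

import torch
from transformers.cache_utils import DynamicCache

cache = DynamicCache()
# Illustrative shapes: (batch, num_heads, seq_len, head_dim)
cache.update(torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8), layer_idx=0)

# Without a registered pytree node, the cache is treated as one opaque leaf,
# so torch.export cannot see its key/value tensors.
leaves, _ = torch.utils._pytree.tree_flatten(cache)
print(len(leaves))  # 1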
4 changes: 2 additions & 2 deletions src/transformers/integrations/__init__.py
@@ -55,7 +55,7 @@
"eetq": ["replace_with_eetq_linear"],
"fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
"finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
"fsdp": ["is_fsdp_managed_module"],
"fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
"ggml": [
"GGUF_CONFIG_MAPPING",
"GGUF_TOKENIZER_MAPPING",
@@ -204,7 +204,7 @@
from .eetq import replace_with_eetq_linear
from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
from .fsdp import is_fsdp_managed_module
from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
from .ggml import (
GGUF_CONFIG_MAPPING,
GGUF_TOKENIZER_MAPPING,
72 changes: 70 additions & 2 deletions src/transformers/integrations/executorch.py
@@ -15,7 +15,14 @@

import torch

from ..cache_utils import DynamicCache, EncoderDecoderCache, HybridCache, StaticCache
from ..cache_utils import (
DynamicCache,
DynamicLayer,
DynamicSlidingWindowLayer,
EncoderDecoderCache,
HybridCache,
StaticCache,
)
from ..generation.configuration_utils import GenerationConfig
from ..masking_utils import (
ALL_MASK_ATTENTION_FUNCTIONS,
@@ -24,7 +31,11 @@
prepare_padding_mask,
)
from ..modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ..pytorch_utils import is_torch_greater_or_equal, is_torch_greater_or_equal_than_2_3
from ..pytorch_utils import (
is_torch_greater_or_equal,
is_torch_greater_or_equal_than_2_3,
is_torch_greater_or_equal_than_2_6,
)


# Add this to src/transformers/integrations/executorch.py
@@ -824,6 +835,8 @@ def __init__(self, model, max_static_cache_length, batch_size):
self.static_cache.early_initialization(batch_size, num_heads, head_dim, torch.float32, "cpu")
self.cache = EncoderDecoderCache(self.static_cache, DynamicCache())

register_dynamic_cache_export_support()

# Register cache buffers to make them exportable
for i in range(len(self.static_cache)):
self.register_buffer(f"key_cache_{i}", self.static_cache.layers[i].keys, persistent=False)
@@ -996,6 +1009,8 @@ def export_with_dynamic_cache(
ALL_ATTENTION_FUNCTIONS.register("sdpa_without_vmap", ALL_ATTENTION_FUNCTIONS["sdpa"])
model.config._attn_implementation = "sdpa_without_vmap"

register_dynamic_cache_export_support()

with torch.no_grad():
exported_program = torch.export.export(
model,
@@ -1011,6 +1026,59 @@
return exported_program


def register_dynamic_cache_export_support():
"""
Utilities for `DynamicCache` <> torch.export support
"""

try:
torch.utils._pytree.register_pytree_node(
DynamicCache,
lambda dynamic_cache: torch.utils._pytree._dict_flatten(_get_cache_dict(dynamic_cache)),
_unflatten_dynamic_cache,
serialized_type_name=f"{DynamicCache.__module__}.{DynamicCache.__name__}",
flatten_with_keys_fn=lambda dynamic_cache: torch.utils._pytree._dict_flatten_with_keys(
_get_cache_dict(dynamic_cache)
),
)
# TODO (tmanlaibaatar) This won't be needed in torch 2.7.
torch.fx._pytree.register_pytree_flatten_spec(
DynamicCache,
lambda cache, spec: torch.fx._pytree._dict_flatten_spec(_get_cache_dict(cache), spec),
)
    # Catch the already-registered case (e.g. when the registration runs more than once across test runs)
except ValueError as e:
if "already registered as pytree node" not in str(e):
raise


def _get_cache_dict(cache: DynamicCache):
"""Convert cache to dictionary format for pytree operations."""
if any(not isinstance(layer, (DynamicLayer, DynamicSlidingWindowLayer)) for layer in cache.layers):
raise RuntimeError("This pytree flattening function should only be applied to DynamicCache")

if not is_torch_greater_or_equal_than_2_6:
logging.warning("DynamicCache + torch.export is tested on torch 2.6.0+ and may not work on earlier versions.")

return {
"key_cache": [layer.keys for layer in cache.layers if layer.keys is not None],
"value_cache": [layer.values for layer in cache.layers if layer.values is not None],
}


def _unflatten_dynamic_cache(values, context: torch.utils._pytree.Context):
dictionary = torch.utils._pytree._dict_unflatten(values, context)
cache = DynamicCache()
# Reconstruct layers from keys and values lists
key_list = dictionary.get("key_cache", [])
value_list = dictionary.get("value_cache", [])
for idx in range(max(len(key_list), len(value_list))):
key = key_list[idx] if idx < len(key_list) else None
value = value_list[idx] if idx < len(value_list) else None
cache.update(key, value, idx)
return cache


def sdpa_mask_without_vmap(
batch_size: int,
cache_position: torch.Tensor,
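The new register_dynamic_cache_export_support() makes that registration explicit and safe to call repeatedly. A minimal round-trip sketch of what it enables, assuming the helper is imported from transformers.integrations.executorch as added above (tensor shapes are illustrative):

import torch
from transformers.cache_utils import DynamicCache
from transformers.integrations.executorch import register_dynamic_cache_export_support

register_dynamic_cache_export_support()  # idempotent: a repeated registration is swallowed

cache = DynamicCache()
cache.update(torch.zeros(1, 2, 3, 8), torch.zeros(1, 2, 3, 8), layer_idx=0)

# With the node registered, flattening yields the raw key/value tensors that
# torch.export traces through, and unflattening rebuilds an equivalent cache.
leaves, spec = torch.utils._pytree.tree_flatten(cache)
rebuilt = torch.utils._pytree.tree_unflatten(leaves, spec)
assert isinstance(rebuilt, DynamicCache) and len(leaves) == 2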
17 changes: 16 additions & 1 deletion src/transformers/integrations/fsdp.py
@@ -13,9 +13,10 @@
# limitations under the License.
from __future__ import annotations

import os
from typing import TYPE_CHECKING

from ..utils import is_torch_available
from ..utils import is_torch_available, strtobool


if TYPE_CHECKING:
@@ -36,3 +37,17 @@ def is_fsdp_managed_module(module: nn.Module) -> bool:
return isinstance(module, torch.distributed.fsdp.FullyShardedDataParallel) or getattr(
module, "_is_fsdp_managed_module", False
)


def is_fsdp_enabled():
if is_torch_available():
import torch

return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
)

return False
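For reference, the new is_fsdp_enabled() requires both an initialized torch.distributed process group and two Accelerate-style environment variables. A hedged single-process sketch; under accelerate launch or torchrun with the process group initialized, the same call would return True:

import os
from transformers.integrations.fsdp import is_fsdp_enabled

os.environ["ACCELERATE_USE_FSDP"] = "true"
os.environ["FSDP_CPU_RAM_EFFICIENT_LOADING"] = "true"

# Still False here: torch.distributed.is_initialized() is False in a plain
# single-process interpreter, and all four conditions must hold.
print(is_fsdp_enabled())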
12 changes: 1 addition & 11 deletions src/transformers/modeling_utils.py
@@ -54,7 +54,7 @@
from .distributed import DistributedConfig
from .dynamic_module_utils import custom_object_save
from .generation import CompileConfig, GenerationConfig
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled
from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
from .integrations.accelerate import find_tied_parameters, init_empty_weights
from .integrations.deepspeed import _load_state_dict_into_zero3_model
from .integrations.eager_paged import eager_paged_attention_forward
@@ -124,7 +124,6 @@
is_torch_xla_available,
is_torch_xpu_available,
logging,
strtobool,
)
from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
@@ -182,15 +181,6 @@
from torch.distributed.tensor import DTensor


def is_fsdp_enabled():
return (
torch.distributed.is_available()
and torch.distributed.is_initialized()
and strtobool(os.environ.get("ACCELERATE_USE_FSDP", "False")) == 1
and strtobool(os.environ.get("FSDP_CPU_RAM_EFFICIENT_LOADING", "False")) == 1
)


def is_local_dist_rank_0():
return (
torch.distributed.is_available()
Expand Down