Merged

Commits (32 total; changes shown from 8 commits)
3a960bb - Bump transformers to 4.54.1 (jackzhxng, Aug 1, 2025)
3d223a2 - Bump torch (jackzhxng, Aug 1, 2025)
207f8b1 - Fix no module found error for custom_kv_cache (jackzhxng, Aug 4, 2025)
bc82841 - Try to fix Missing operator: [8] quantized_decomposed::embedding_byte… (jackzhxng, Aug 4, 2025)
35fc918 - Fix quantization requires torchao >= 0.11.0 (jackzhxng, Aug 4, 2025)
6a26464 - Fix sliding window, print loaded ops (jackzhxng, Aug 5, 2025)
4d68263 - Bump ET nightly pin, fixes missing quantized ops (jackzhxng, Aug 5, 2025)
6a3e1d4 - Fix no Q_ANNOTATION_KEY (jackzhxng, Aug 5, 2025)
2b5fe7e - Try to fix segfault/bus error by holding onto temp dir (jackzhxng, Aug 8, 2025)
bb0089c - Bigger mac runners (jackzhxng, Aug 9, 2025)
72802e3 - Revert "Bigger mac runners" (jackzhxng, Aug 10, 2025)
9876c7e - Add helpful logs (jackzhxng, Aug 10, 2025)
19f4d21 - Re-enable smollm3 tests for linux (jackzhxng, Aug 10, 2025)
99805f8 - Experiment reverting transformers bump (jackzhxng, Aug 10, 2025)
108ed17 - Revert "Experiment reverting transformers bump" (jackzhxng, Aug 13, 2025)
59778eb - Formatting and remove logs (jackzhxng, Aug 13, 2025)
ff8a2a1 - Bump ET release from 0.6 -> 0.7 (jackzhxng, Aug 14, 2025)
a3009ca - Bisect down to ET 20250701 (jackzhxng, Aug 14, 2025)
ae488b1 - Experiment reverting transformers bump (jackzhxng, Aug 10, 2025)
b7a2fa1 - Clean (jackzhxng, Aug 14, 2025)
1e0a671 - Bisect down to ET 20250628 (jackzhxng, Aug 14, 2025)
896f0da - Bisect down to ET 20250626 (jackzhxng, Aug 14, 2025)
abd641b - Revert "Bisect down to ET 20250626" (jackzhxng, Aug 15, 2025)
7f7f9c2 - Revert "Bisect down to ET 20250628" (jackzhxng, Aug 15, 2025)
5f8a56f - Revert "Experiment reverting transformers bump" (jackzhxng, Aug 15, 2025)
92bc2ba - Revert "Bisect down to ET 20250701" (jackzhxng, Aug 15, 2025)
4abb2ec - Skip mac tests (jackzhxng, Aug 15, 2025)
ad9b639 - Remove unnecessary ET 0.6 guards (jackzhxng, Aug 15, 2025)
b252038 - Ruff format (jackzhxng, Aug 15, 2025)
e135310 - Remove all transformers < 4.54 guards (jackzhxng, Aug 15, 2025)
671bc06 - Merge branch 'main' into jz/bump-transformers (jackzhxng, Aug 15, 2025)
70338e9 - Format (jackzhxng, Aug 15, 2025)
16 changes: 8 additions & 8 deletions install_dev.py
@@ -5,21 +5,21 @@

def install_torch_nightly_deps():
"""Install torch related dependencies from pinned nightly"""
EXECUTORCH_NIGHTLY_VERSION = "dev20250625"
TORCHAO_NIGHTLY_VERSION = "dev20250620"
EXECUTORCH_NIGHTLY_VERSION = "dev20250730"
TORCHAO_NIGHTLY_VERSION = "dev20250730"
# Torch nightly is aligned with pinned nightly in https://github.com/pytorch/executorch/blob/main/install_requirements.py#L74
TORCH_NIGHTLY_VERSION = "dev20250601"
TORCH_NIGHTLY_VERSION = "dev20250725"
subprocess.check_call(
[
sys.executable,
"-m",
"pip",
"install",
f"executorch==0.7.0.{EXECUTORCH_NIGHTLY_VERSION}",
f"torch==2.8.0.{TORCH_NIGHTLY_VERSION}",
f"torchvision==0.23.0.{TORCH_NIGHTLY_VERSION}",
f"executorch==0.8.0.{EXECUTORCH_NIGHTLY_VERSION}",
f"torch==2.9.0.{TORCH_NIGHTLY_VERSION}",
f"torchvision==0.24.0.{TORCH_NIGHTLY_VERSION}",
f"torchaudio==2.8.0.{TORCH_NIGHTLY_VERSION}",
f"torchao==0.12.0.{TORCHAO_NIGHTLY_VERSION}",
f"torchao==0.13.0.{TORCHAO_NIGHTLY_VERSION}",
"--extra-index-url",
"https://download.pytorch.org/whl/nightly/cpu",
]
@@ -34,7 +34,7 @@ def install_dep_from_source():
"-m",
"pip",
"install",
"git+https://github.com/huggingface/transformers@896e9cea1ade521b2648f4798218550f6c72190c#egg=transformers", # 4.53.1
"git+https://github.com/huggingface/transformers@9c641dc16154964e5ffc0c13e9ec6aaffa295ed6#egg=transformers", # 4.54.1
]
)
subprocess.check_call(
67 changes: 32 additions & 35 deletions optimum/executorch/attentions/custom_kv_cache.py
@@ -54,12 +54,12 @@ def __init__(

# Create a list of CustomKVCache instances, one per layer
self.kv_cache = torch.nn.ModuleList()
for _ in range(config.num_hidden_layers):
[Comment] Collaborator: what happened here? like config doesnt exist anymore?
[Reply] jackzhxng (Author): It still exists, feel like it's more idiomatic to iterate over the actual layers
for layer in self.layers:
layer_cache = CustomKVCache(
max_batch_size=self.max_batch_size,
max_context_length=self.max_cache_len,
n_heads=self.num_key_value_heads,
head_dim=self.head_dim,
max_batch_size=layer.max_batch_size,
max_context_length=layer.max_cache_len,
n_heads=layer.num_heads,
head_dim=layer.head_dim,
dtype=dtype,
)
self.kv_cache.append(layer_cache)
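
The change above replaces index-based iteration over `config.num_hidden_layers` with iteration over the cache's layer objects, as discussed in the thread. A minimal, self-contained sketch of that pattern; `LayerSpec` and `SimpleLayerCache` are hypothetical stand-ins, not the real transformers or ExecuTorch classes:

```python
# Hypothetical sketch only: LayerSpec and SimpleLayerCache are illustrative
# stand-ins, not classes from transformers or ExecuTorch.
from dataclasses import dataclass

import torch


@dataclass
class LayerSpec:
    max_batch_size: int
    max_cache_len: int
    num_heads: int
    head_dim: int


class SimpleLayerCache(torch.nn.Module):
    def __init__(self, spec: LayerSpec, dtype: torch.dtype = torch.float32):
        super().__init__()
        # Pre-allocate key/value buffers sized from this layer's own geometry.
        shape = (spec.max_batch_size, spec.num_heads, spec.max_cache_len, spec.head_dim)
        self.register_buffer("k_cache", torch.zeros(shape, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(shape, dtype=dtype))


# Build one cache per layer by reading sizes off each layer, not off a global config.
layers = [LayerSpec(max_batch_size=1, max_cache_len=128, num_heads=8, head_dim=64) for _ in range(4)]
kv_cache = torch.nn.ModuleList([SimpleLayerCache(layer) for layer in layers])
print(len(kv_cache), tuple(kv_cache[0].k_cache.shape))
```
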
@@ -202,32 +202,29 @@ def __init__(
layer_device_map=layer_device_map,
)

# make sure layer_device_map is none
assert layer_device_map is None
assert device is None or device == "cpu", "Device must be None or 'cpu'"

self.cache_position = None
# Create a list of cache instances, one per layer
# Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers
# Create a list of cache instances, one per layer.
# Use CustomKVCache for global layers and CustomRingKVCache for sliding window layers.
self.kv_cache = torch.nn.ModuleList()
for layer_idx in range(config.num_hidden_layers):
# newer version of transfomer has is_sliding defined
# for HybridCache
if self.is_sliding[layer_idx]:
for layer in self.layers:
if layer.is_sliding:
# This is a sliding window layer
layer_cache = CustomRingKVCache(
max_batch_size=self.max_batch_size,
max_context_length=self.sliding_window_len,
n_heads=self.num_key_value_heads,
head_dim=self.head_dim,
max_batch_size=layer.max_batch_size,
max_context_length=layer.max_cache_len,
[Comment] Collaborator: wait what is happening here? is this same as sliding_window_len
[Reply] jackzhxng (Author), Aug 15, 2025
n_heads=layer.num_heads,
head_dim=layer.head_dim,
dtype=dtype,
)
else:
layer_cache = CustomKVCache(
max_batch_size=self.max_batch_size,
max_context_length=self.max_cache_len,
n_heads=self.num_key_value_heads,
head_dim=self.head_dim,
max_batch_size=layer.max_batch_size,
max_context_length=layer.max_cache_len,
n_heads=layer.num_heads,
head_dim=layer.head_dim,
dtype=dtype,
)
self.kv_cache.append(layer_cache)
@@ -284,7 +281,7 @@ def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:

# For CustomRingKVCache, we need to handle the sequence length differently
layer_cache = self.kv_cache[layer_idx]
if self.is_sliding[layer_idx]:
if self.layers[layer_idx].is_sliding:
# CustomRingKVCache cache_position_manager which
# maintains cache position for each slot in the kv cache
# we return the max position + 1 to indicate max position
Expand All @@ -308,7 +305,7 @@ def get_layer_cache(self, layer_idx: int):

def replace_with_et_custom_kv_cache(module, config, generation_config, cache_dtype):
"""
Replace all KV caches in the module with ETCustomStaticCache.
Replace all KV caches in the module with ETCustomStaticCache or ETCustomHybridCache.
This modifies the model in place.

Args:
@@ -342,18 +339,18 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
if getattr(module, "replace_cache", None) is not None:
static_cache = ETCustomStaticCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=cache_dtype,
)
module.replace_cache(static_cache)
else:
module.static_cache = ETCustomStaticCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=cache_dtype,
)
# Dont know why we need to this even though
@@ -370,25 +367,25 @@ def _replace_with_et_custom_kv_cache(module, config, generation_config, cache_dt
if getattr(module, "replace_cache", None) is not None:
hybrid_cache = ETCustomHybridCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=cache_dtype,
)
module.replace_cache(hybrid_cache)
else:
module.cache = ETCustomHybridCache(
config=config,
max_batch_size=generation_config.cache_config.batch_size,
max_cache_len=generation_config.cache_config.max_cache_len,
device=generation_config.cache_config.device,
max_batch_size=generation_config.cache_config.get("batch_size"),
max_cache_len=generation_config.cache_config.get("max_cache_len"),
device=generation_config.cache_config.get("device"),
dtype=cache_dtype,
)
# Register cache attributes for each layer
for i in range(len(module.cache.kv_cache)):
setattr(module, f"key_cache_{i}", module.cache.kv_cache[i].k_cache)
setattr(module, f"value_cache_{i}", module.cache.kv_cache[i].v_cache)
if module.cache.is_sliding[i]:
if module.cache.layers[i].is_sliding:
# Register cache_positions as buffer for sliding window layers
# This prevents it from being traced as a constant
module.register_buffer(
8 changes: 8 additions & 0 deletions optimum/executorch/modeling.py
@@ -24,6 +24,7 @@
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
from transformers import (
AutoModelForCausalLM,
AutoModelForImageClassification,
@@ -185,6 +186,13 @@ def _from_pretrained(
subfolder=subfolder,
local_files_only=local_files_only,
)

from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib # noqa
from executorch.kernels import quantized # noqa
from executorch.extension.pybindings.portable_lib import _get_operator_names
print("----------- LOADED OPS ----------")
print('\n'.join(_get_operator_names()))
print("---------------------------------")
model = _load_for_executorch(model_cache_path)
logging.info(
f"Loaded model from {model_cache_path} ({os.path.getsize(model_cache_path) / (1024 * 1024):.2f} MB)"
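
The imports and logging added above exist to make sure the quantized kernels are registered with the runtime before the .pte program is loaded. A condensed sketch of that sanity check, with a hypothetical model path; the imports mirror the ones in the diff:

```python
# Importing these modules registers their operators with the ExecuTorch runtime;
# the imported names themselves are unused (hence the noqa imports in the diff).
from torch.ao.quantization.fx._decomposed import quantized_decomposed_lib  # noqa: F401
from executorch.kernels import quantized  # noqa: F401
from executorch.extension.pybindings.portable_lib import (
    _get_operator_names,
    _load_for_executorch,
)

ops = _get_operator_names()
print(f"{len(ops)} operators registered")
if not any("quantized_decomposed" in name for name in ops):
    raise RuntimeError("quantized_decomposed ops missing; check the executorch install")

module = _load_for_executorch("model.pte")  # hypothetical path to an exported program
```
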
12 changes: 8 additions & 4 deletions optimum/exporters/executorch/integrations.py
@@ -28,6 +28,7 @@
)
from transformers.generation.configuration_utils import GenerationConfig

from executorch import version as executorch_version
from optimum.executorch.attentions.custom_sdpa import get_custom_sdpa_for_ring_kv_cache
from optimum.utils.import_utils import is_transformers_version

@@ -89,7 +90,10 @@ def _prepare_export_inputs(self):
return example_input_ids, example_cache_position, dynamic_shapes, strict

def _register_attention_mask_for_4_53(self, exportable_module: torch.nn.Module):
if is_transformers_version(">=", "4.53.0.dev0"):
if (
is_transformers_version(">=", "4.53.0.dev0")
and parse(executorch_version.__version__).base_version > "0.6.0"
):
from transformers.integrations.executorch import sdpa_mask_without_vmap
from transformers.masking_utils import AttentionMaskInterface
from transformers.modeling_utils import AttentionInterface
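
The guards added in this file gate new behavior on both the transformers version and the installed ExecuTorch version. A small sketch of that gating pattern, assuming executorch and optimum are importable; the helper name is ours, and the base_version comparison is a string comparison, matching the check in the diff:

```python
from packaging.version import parse

from executorch import version as executorch_version
from optimum.utils.import_utils import is_transformers_version


def can_register_sdpa_mask_without_vmap() -> bool:
    # Hypothetical helper: require transformers >= 4.53 and an ExecuTorch build
    # newer than 0.6.0, mirroring the combined guard in the diff above.
    return (
        is_transformers_version(">=", "4.53.0.dev0")
        and parse(executorch_version.__version__).base_version > "0.6.0"
    )


if can_register_sdpa_mask_without_vmap():
    print("ok to register sdpa_mask_without_vmap as the attention mask interface")
```
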
Expand Down Expand Up @@ -126,7 +130,7 @@ def export(
)
self._register_attention_mask_for_4_53(exportable_module)

if self.use_custom_kv_cache:
if self.use_custom_kv_cache and parse(executorch_version.__version__).base_version > "0.6.0":
from optimum.executorch.attentions.custom_kv_cache import (
replace_with_et_custom_kv_cache,
)
@@ -395,8 +399,8 @@ def _export_decoder(self, decoder_input_ids, encoder_hidden_states, cache_positi
wrapped_decoder = (
Seq2SeqLMDecoderExportableModuleWithStaticCache(
model=self.full_model,
max_static_cache_length=self.generation_config.cache_config.max_cache_len,
batch_size=self.generation_config.cache_config.batch_size,
max_static_cache_length=self.generation_config.cache_config.get("max_cache_len"),
batch_size=self.generation_config.cache_config.get("batch_size"),
)
.to("cpu")
.eval()
4 changes: 0 additions & 4 deletions optimum/exporters/executorch/quantization.py
@@ -26,10 +26,6 @@ def quantize_model_(
if not (qlinear_config or qembedding_config):
return

# TODO: Update torchao to use 0.11.0 once released
if parse(torchao.__version__) < parse("0.11.0.dev0"):
raise RuntimeError("Quantization requires torchao >= 0.11.0. Please upgrade torchao.")

from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.quant_api import (
Int8DynamicActivationIntxWeightConfig,
4 changes: 2 additions & 2 deletions optimum/exporters/executorch/utils.py
@@ -53,8 +53,8 @@ def save_config_to_constant_methods(
# Check for cache_config and its attributes
cache_config = getattr(generation_config, "cache_config", None)
if cache_config is not None:
max_batch_size = getattr(cache_config, "batch_size", None)
max_seq_len = getattr(cache_config, "max_cache_len", None)
max_batch_size = cache_config.get("batch_size")
max_seq_len = cache_config.get("max_cache_len")

if max_batch_size is not None:
metadata["get_max_batch_size"] = max_batch_size
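
The getattr-to-get changes above (here and in custom_kv_cache.py and integrations.py) track cache_config becoming dict-like in newer transformers releases. A hedged compatibility sketch; the helper is ours, not part of either library:

```python
# Hypothetical helper: read a cache_config field whether it is dict-like
# (newer transformers) or an object exposing attributes (older releases).
def get_cache_config_value(cache_config, key, default=None):
    if cache_config is None:
        return default
    if hasattr(cache_config, "get"):
        return cache_config.get(key, default)
    return getattr(cache_config, key, default)


# Example with a plain dict, mirroring the keys used in the diff.
cache_config = {"batch_size": 1, "max_cache_len": 1024}
print(get_cache_config_value(cache_config, "batch_size"))   # 1
print(get_cache_config_value(cache_config, "device"))       # None
```
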
2 changes: 1 addition & 1 deletion setup.py
@@ -14,7 +14,7 @@
INSTALL_REQUIRE = [
"optimum~=1.24",
"executorch>=0.6.0",
"transformers==4.51.3",
"transformers==4.54.1",
]

TESTS_REQUIRE = [