Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion litert_torch/generative/export_hf/core/export_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import gc
import json
import os
import time

import huggingface_hub
from litert_torch import fx_infra
Expand All @@ -36,6 +35,7 @@
from litert_torch.generative.export_hf.core.split_cache import attention as _
from litert_torch.generative.export_hf.core.split_cache import exportable_module as split_cache_module
from litert_torch.generative.export_hf.model_ext import exportables as model_ext_exportables
from litert_torch.generative.export_hf.model_ext import extension as model_ext_extension
from litert_torch.generative.export_hf.model_ext import patches as model_ext_patches
from litert_torch.generative.tools import tokenizer_to_sentencepiece_lib as tokenizer_lib
import torch
Expand Down Expand Up @@ -189,6 +189,16 @@ def load_model(
)


def update_export_config(
    export_config: exportable_module.ExportableModuleConfig,
    source_model_artifacts: SourceModelArtifacts,
) -> exportable_module.ExportableModuleConfig:
  """Applies model-specific extension updates to the export config.

  Thin wrapper that delegates to model_ext_extension.update_export_config,
  which may return a modified copy of the config (e.g. a different cache
  implementation) based on the source model's architecture.

  Args:
    export_config: The base export configuration.
    source_model_artifacts: Loaded source model artifacts; only
      `model_config` is consulted.

  Returns:
    The possibly-updated export config.
  """
  return model_ext_extension.update_export_config(
      export_config, source_model_artifacts.model_config
  )


def get_prefill_decode_exportable_cls(
model_config: transformers.PretrainedConfig,
export_config: exportable_module.ExportableModuleConfig,
Expand Down
3 changes: 3 additions & 0 deletions litert_torch/generative/export_hf/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@ def run_export_tasks(
auto_model_override=auto_model_override,
task=task,
)
export_config = export_lib.update_export_config(
export_config, source_model_artifacts
)
exported_model_artifacts = export_lib.ExportedModelArtifacts()

# Suppress deprecation warnings to be compatible with older PyTorch.
Expand Down
2 changes: 2 additions & 0 deletions litert_torch/generative/export_hf/model_ext/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,5 @@

from litert_torch.generative.export_hf.model_ext.gemma3 import patch as _
from litert_torch.generative.export_hf.model_ext.gemma3n import patch as _
from litert_torch.generative.export_hf.model_ext.lfm2 import cache as _
from litert_torch.generative.export_hf.model_ext.lfm2 import patch as _
36 changes: 36 additions & 0 deletions litert_torch/generative/export_hf/model_ext/extension.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# Copyright 2026 The LiteRT Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Extension for HF integration."""

import dataclasses

from litert_torch.generative.export_hf.core import exportable_module
import transformers


def update_export_config(
    export_config: exportable_module.ExportableModuleConfig,
    model_config: transformers.PretrainedConfig,
) -> exportable_module.ExportableModuleConfig:
  """Returns an export config adjusted for the model architecture.

  For LFM2 models, forces the LFM2-specific cache implementation (split
  cache is incompatible with it). All other model types pass through
  unchanged.
  """
  if model_config.model_type != 'lfm2':
    return export_config
  # LFM2 carries convolution state in its cache, which the split-cache
  # export path cannot represent.
  if export_config.split_cache:
    raise ValueError('Split cache is not supported for LFM2.')
  return dataclasses.replace(
      export_config, cache_implementation='LiteRTLFM2Cache'
  )
187 changes: 187 additions & 0 deletions litert_torch/generative/export_hf/model_ext/lfm2/cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
# Copyright 2026 The LiteRT Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Cache for LFM2."""

from typing import List, Tuple
from litert_torch.generative.export_hf.core import cache as cache_lib
from litert_torch.generative.export_hf.core import cache_base as cache_base_lib
from litert_torch.generative.export_hf.core import exportable_module_config
import torch
import torch.utils._pytree as pytree


class LiteRTLFM2CacheLayer(cache_lib.LiteRTLMCacheLayer):
  """Cache layer carrying convolution state for LFM2 "conv" layers.

  LFM2's short-convolution layers have no key/value cache; instead they
  carry a rolling window of the last input timesteps. This layer stores
  that window in `conv_state` while satisfying the base cache-layer
  interface with placeholder KV tensors.
  """

  def __init__(
      self,
      conv_state: torch.Tensor,
      key_cache: cache_lib.KeyCache | None = None,
      value_cache: cache_lib.ValueCache | None = None,
      batch_size: int = 1,
      k_ts_idx: int = 2,
      v_ts_idx: int = 3,
      **kwargs,
  ):
    # key_cache/value_cache are accepted for signature compatibility with
    # the base layer but deliberately ignored: conv layers keep no KV
    # state, so minimal 1x1x1x1 placeholders are passed to the base class.
    dummy_key_cache = torch.zeros((1, 1, 1, 1))
    dummy_value_cache = torch.zeros((1, 1, 1, 1))
    super().__init__(
        dummy_key_cache,
        dummy_value_cache,
        batch_size,
        k_ts_idx,
        v_ts_idx,
        **kwargs,
    )
    # Rolling convolution input window; shape set by create_from_config is
    # (batch_size, hidden_size, conv_L_cache - 1).
    self.conv_state = conv_state

  @classmethod
  def create_from_config(
      cls,
      model_config,
      layer_index,
      export_config: exportable_module_config.ExportableModuleConfig,
      **kwargs,
  ) -> "LiteRTLFM2CacheLayer":
    """Creates a KV cache from the model config.

    Args:
      model_config: HF model config; must mark this layer as "conv" and
        provide `hidden_size` and `conv_L_cache`.
      layer_index: Index of the layer in the model.
      export_config: Export configuration (supplies the batch size).
      **kwargs: Forwarded to the constructor.

    Returns:
      A zero-initialized conv-state cache layer.
    """
    # This layer type is only valid for "conv" layers; attention layers
    # use the plain LiteRTLMCacheLayer instead.
    assert model_config.layer_types[layer_index] == "conv"
    # State holds the last (kernel_size - 1) timesteps needed to make the
    # convolution causal without internal padding.
    c_state_shape = (
        export_config.batch_size,
        model_config.hidden_size,
        model_config.conv_L_cache - 1,
    )
    c_state = torch.zeros(c_state_shape, dtype=torch.float32)
    return cls(
        c_state,
        batch_size=export_config.batch_size,
        **kwargs,
    )


@cache_base_lib.register_cache_implementation
class LiteRTLFM2Cache(cache_lib.LiteRTLMCache):
  """Optimized Cache class for LFM2 integration."""

  @classmethod
  def create_from_config(
      cls,
      model_config,
      export_config: exportable_module_config.ExportableModuleConfig,
      **kwargs,
  ) -> "LiteRTLFM2Cache":
    """Builds a per-layer cache from the model config.

    "conv" layers receive an LFM2 conv-state layer; every other layer
    type gets the standard KV cache layer.
    """

    def _layer_cls(index: int):
      # Dispatch on the declared layer type for each index.
      if model_config.layer_types[index] == "conv":
        return LiteRTLFM2CacheLayer
      return cache_lib.LiteRTLMCacheLayer

    layers = [
        _layer_cls(index).create_from_config(
            model_config,
            index,
            export_config,
        )
        for index in range(model_config.num_hidden_layers)
    ]
    return cls(layers)


def _flatten_kvc_t(
    kvc: LiteRTLFM2Cache,
) -> Tuple[
    List[torch.Tensor], Tuple[List[str], Tuple[int, int, int, int, List[bool]]]
]:
  """Flattens the cache into a tensor list plus reconstruction context."""
  first_layer = kvc.layers[0]
  assert isinstance(first_layer, cache_base_lib.LiteRTLMCacheLayerMixin)
  tensors: List[torch.Tensor] = []
  names: List[str] = []
  conv_flags: List[bool] = []
  for idx, layer in enumerate(kvc.layers):
    if isinstance(layer, LiteRTLFM2CacheLayer):
      # Conv layers contribute a single state tensor.
      conv_flags.append(True)
      tensors.append(layer.conv_state)
      names.append(f"c_{idx}")
    else:
      # Attention layers contribute a key tensor then a value tensor.
      conv_flags.append(False)
      tensors.append(layer.keys)
      names.append(f"k_{idx}")
      tensors.append(layer.values)
      names.append(f"v_{idx}")
  # Context carries everything _unflatten_kvc_t needs to rebuild the cache.
  context = (
      names,
      (
          first_layer.get_batch_size(),
          len(kvc.layers),
          first_layer.get_k_ts_idx(),
          first_layer.get_v_ts_idx(),
          conv_flags,
      ),
  )
  return tensors, context


def _unflatten_kvc_t(
    values: List[torch.Tensor],
    context: Tuple[List[str], Tuple[int, int, int, int, List[bool]]],
) -> LiteRTLFM2Cache:
  """Unflattens the cache from a list of tensors.

  Inverse of `_flatten_kvc_t`: rebuilds conv-state layers for indices
  flagged in `is_conv` and standard KV layers for the rest.

  Args:
    values: Flattened tensors, ordered as produced by `_flatten_kvc_t`.
    context: `(flat_names, (batch_size, num_layers, k_ts_idx, v_ts_idx,
      is_conv))` as produced by `_flatten_kvc_t`.

  Returns:
    A reconstructed LiteRTLFM2Cache.
  """
  flat_names = context[0]
  batch_size, num_layers, k_ts_idx, v_ts_idx, is_conv = context[1]
  # Build the name->position map once; repeated list.index() calls would be
  # O(n) each and O(n^2) overall for deep models.
  name_to_idx = {name: idx for idx, name in enumerate(flat_names)}
  layers = []
  for i in range(num_layers):
    if is_conv[i]:
      layers.append(
          LiteRTLFM2CacheLayer(
              conv_state=values[name_to_idx[f"c_{i}"]],
              batch_size=batch_size,
          )
      )
    else:
      layers.append(
          cache_lib.LiteRTLMCacheLayer(
              key_cache=values[name_to_idx[f"k_{i}"]],
              value_cache=values[name_to_idx[f"v_{i}"]],
              batch_size=batch_size,
              k_ts_idx=k_ts_idx,
              v_ts_idx=v_ts_idx,
          )
      )
  return LiteRTLFM2Cache(layers)


def _flatten_kvc_t_with_keys(
    kvc: LiteRTLFM2Cache,
):
  """Flatten variant that pairs each tensor with a pytree MappingKey."""
  tensors, (names, _) = _flatten_kvc_t(kvc)
  keyed = []
  for name, tensor in zip(names, tensors):
    keyed.append((pytree.MappingKey(name), tensor))
  return keyed, names


# Module-level side effect: register LiteRTLFM2Cache with torch's pytree so
# that tracing/export machinery can flatten cache objects into plain tensor
# lists and reconstruct them afterwards.
pytree.register_pytree_node(
    LiteRTLFM2Cache,
    _flatten_kvc_t,
    _unflatten_kvc_t,
    flatten_with_keys_fn=_flatten_kvc_t_with_keys,
    serialized_type_name="",
)
33 changes: 33 additions & 0 deletions litert_torch/generative/export_hf/model_ext/lfm2/patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2026 The LiteRT Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Patch for LFM2."""

import contextlib
from litert_torch.generative.export_hf.model_ext import patches as patches_lib
from litert_torch.generative.export_hf.model_ext.lfm2 import short_conv as short_conv_lib
from transformers.models.lfm2 import modeling_lfm2


@patches_lib.register_patch(["lfm2"])
@contextlib.contextmanager
def lfm2_litert_patch():
  """Temporarily swaps Lfm2ShortConv for the LiteRT-exportable variant.

  Registered for "lfm2" models: while the context is active,
  transformers' `modeling_lfm2.Lfm2ShortConv` points at the export-friendly
  implementation; the original class is restored on exit even if the body
  raises.
  """
  # NOTE(review): stray debug print in library code — consider switching to
  # logging.
  print("LFM2 patch applied.")
  original_short_conv = modeling_lfm2.Lfm2ShortConv
  modeling_lfm2.Lfm2ShortConv = short_conv_lib.Lfm2ShortConv

  try:
    yield
  finally:
    # Always restore the upstream class so the patch cannot leak.
    modeling_lfm2.Lfm2ShortConv = original_short_conv
61 changes: 61 additions & 0 deletions litert_torch/generative/export_hf/model_ext/lfm2/short_conv.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright 2026 The LiteRT Torch Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Short convolutions for LFM2."""

from typing import Optional
import torch
from transformers.models.lfm2 import modeling_lfm2


class Lfm2ShortConv(modeling_lfm2.Lfm2ShortConv):
  """Short convolutions for LFM2, suitable for LiteRT inference.

  Rebuilds the convolution without padding and makes causality explicit:
  the last L_cache - 1 input timesteps are carried in the cache's
  `conv_state` and prepended to the current input in `forward`, which
  keeps the op shape-static and export-friendly.
  """

  def __init__(
      self,
      config: modeling_lfm2.Lfm2Config,
      layer_idx: int,
  ):
    super().__init__(config, layer_idx)
    # Depthwise conv (groups == in_channels == out_channels), replacing
    # whatever conv the base __init__ created. self.L_cache and self.bias
    # are assumed to be set by the base class — TODO confirm against
    # transformers' Lfm2ShortConv.
    self.conv = torch.nn.Conv1d(
        in_channels=config.hidden_size,
        out_channels=config.hidden_size,
        kernel_size=self.L_cache,
        groups=config.hidden_size,
        bias=self.bias,
        padding=0,  # Padding is done in forward as part of state management.
    )

  def forward(
      self,
      hidden_states: torch.Tensor,
      past_key_values=None,
      cache_position: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
  ):
    """Runs the stateful short convolution and updates the cache in place.

    Mutates `past_key_values.layers[self.layer_idx].conv_state`;
    `past_key_values` must therefore be a cache whose layer at this index
    exposes `conv_state` (e.g. LiteRTLFM2CacheLayer). `cache_position` is
    accepted for interface compatibility but unused.
    """
    x = modeling_lfm2.apply_mask_to_padding_states(
        hidden_states, attention_mask
    )
    # Project and split into the three gating streams.
    b, c, x_proj = self.in_proj(x).chunk(3, dim=-1)
    conv_input = b * x_proj
    # Conv1d expects (batch, channels, time).
    conv_input_t = conv_input.transpose(1, 2)
    state = past_key_values.layers[self.layer_idx].conv_state
    # Prepend cached timesteps so the unpadded conv stays causal; output
    # length then equals the current sequence length.
    padded_input = torch.cat([state, conv_input_t], dim=-1)
    # Keep the last L_cache - 1 timesteps for the next call.
    # NOTE(review): if L_cache == 1 this slice is [..., -0:] which keeps the
    # whole padded input rather than an empty state — confirm L_cache > 1.
    next_state = padded_input[:, :, -(self.L_cache - 1) :]
    conv_out = self.conv(padded_input)
    conv_out = conv_out.transpose(1, 2)
    y = c * conv_out
    y = self.out_proj(y)
    # State is written only after the computation succeeds.
    past_key_values.layers[self.layer_idx].conv_state = next_state
    return y
Loading