Skip to content

Commit e85fbbe

Browse files
sirakiin and copybara-github
authored and committed
No public description
PiperOrigin-RevId: 875911043
1 parent 48f11a1 commit e85fbbe

File tree

7 files changed

+333
-1
lines changed

7 files changed

+333
-1
lines changed

litert_torch/generative/export_hf/core/export_lib.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
import gc
1919
import json
2020
import os
21-
import time
2221

2322
import huggingface_hub
2423
from litert_torch import fx_infra
@@ -36,6 +35,7 @@
3635
from litert_torch.generative.export_hf.core.split_cache import attention as _
3736
from litert_torch.generative.export_hf.core.split_cache import exportable_module as split_cache_module
3837
from litert_torch.generative.export_hf.model_ext import exportables as model_ext_exportables
38+
from litert_torch.generative.export_hf.model_ext import extension as model_ext_extension
3939
from litert_torch.generative.export_hf.model_ext import patches as model_ext_patches
4040
from litert_torch.generative.tools import tokenizer_to_sentencepiece_lib as tokenizer_lib
4141
import torch
@@ -189,6 +189,16 @@ def load_model(
189189
)
190190

191191

192+
def update_export_config(
    export_config: exportable_module.ExportableModuleConfig,
    source_model_artifacts: SourceModelArtifacts,
) -> exportable_module.ExportableModuleConfig:
  """Updates export config.

  Delegates to the model-extension hook so that model-specific
  adjustments (e.g. selecting a custom cache implementation) can be
  applied based on the source model's config.

  Args:
    export_config: The base export configuration.
    source_model_artifacts: Loaded source model artifacts; only its
      `model_config` is consulted here.

  Returns:
    A possibly-updated export configuration.
  """
  return model_ext_extension.update_export_config(
      export_config, source_model_artifacts.model_config
  )
200+
201+
192202
def get_prefill_decode_exportable_cls(
193203
model_config: transformers.PretrainedConfig,
194204
export_config: exportable_module.ExportableModuleConfig,

litert_torch/generative/export_hf/export.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,9 @@ def run_export_tasks(
4242
auto_model_override=auto_model_override,
4343
task=task,
4444
)
45+
export_config = export_lib.update_export_config(
46+
export_config, source_model_artifacts
47+
)
4548
exported_model_artifacts = export_lib.ExportedModelArtifacts()
4649

4750
# Suppress deprecation warnings to be compatible with older PyTorch.

litert_torch/generative/export_hf/model_ext/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,3 +16,5 @@
1616

1717
from litert_torch.generative.export_hf.model_ext.gemma3 import patch as _
1818
from litert_torch.generative.export_hf.model_ext.gemma3n import patch as _
19+
from litert_torch.generative.export_hf.model_ext.lfm2 import cache as _
20+
from litert_torch.generative.export_hf.model_ext.lfm2 import patch as _
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
# Copyright 2026 The LiteRT Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Extension for HF integration."""
16+
17+
import dataclasses
18+
19+
from litert_torch.generative.export_hf.core import exportable_module
20+
import transformers
21+
22+
23+
def update_export_config(
    export_config: exportable_module.ExportableModuleConfig,
    model_config: transformers.PretrainedConfig,
) -> exportable_module.ExportableModuleConfig:
  """Applies model-specific adjustments to the export config.

  LFM2 models require a dedicated cache implementation and do not
  support split-cache export; every other model type passes through
  unchanged.
  """
  if model_config.model_type != 'lfm2':
    return export_config
  if export_config.split_cache:
    raise ValueError('Split cache is not supported for LFM2.')
  return dataclasses.replace(
      export_config, cache_implementation='LiteRTLFM2Cache'
  )
Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
# Copyright 2026 The LiteRT Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Cache for LFM2."""
16+
17+
from typing import List, Tuple
18+
from litert_torch.generative.export_hf.core import cache as cache_lib
19+
from litert_torch.generative.export_hf.core import cache_base as cache_base_lib
20+
from litert_torch.generative.export_hf.core import exportable_module_config
21+
import torch
22+
import torch.utils._pytree as pytree
23+
24+
25+
class LiteRTLFM2CacheLayer(cache_lib.LiteRTLMCacheLayer):
  """Optimized Cache layer class for LFM2 integration.

  Represents a convolutional ("conv") LFM2 layer. Instead of key/value
  tensors it carries a rolling convolution state in `conv_state`.
  """

  def __init__(
      self,
      conv_state: torch.Tensor,
      key_cache: cache_lib.KeyCache | None = None,
      value_cache: cache_lib.ValueCache | None = None,
      batch_size: int = 1,
      k_ts_idx: int = 2,
      v_ts_idx: int = 3,
      **kwargs,
  ):
    # NOTE(review): key_cache/value_cache are accepted for signature
    # compatibility with the parent layer but are ignored here —
    # minimal (1, 1, 1, 1) placeholder tensors are passed to the
    # parent instead. Confirm this is intentional.
    dummy_key_cache = torch.zeros((1, 1, 1, 1))
    dummy_value_cache = torch.zeros((1, 1, 1, 1))
    super().__init__(
        dummy_key_cache,
        dummy_value_cache,
        batch_size,
        k_ts_idx,
        v_ts_idx,
        **kwargs,
    )
    # Rolling convolution state; replaced each step by the patched
    # Lfm2ShortConv forward pass.
    self.conv_state = conv_state

  @classmethod
  def create_from_config(
      cls,
      model_config,
      layer_index,
      export_config: exportable_module_config.ExportableModuleConfig,
      **kwargs,
  ) -> "LiteRTLFM2CacheLayer":
    """Creates a KV cache from the model config.

    Args:
      model_config: HF model config; `layer_types[layer_index]` must
        be "conv".
      layer_index: Index of the layer within the model.
      export_config: Export configuration providing the batch size.
      **kwargs: Forwarded to the constructor.

    Returns:
      A cache layer with a zero-initialized conv state of shape
      (batch_size, hidden_size, conv_L_cache - 1).
    """
    assert model_config.layer_types[layer_index] == "conv"
    c_state_shape = (
        export_config.batch_size,
        model_config.hidden_size,
        model_config.conv_L_cache - 1,
    )
    c_state = torch.zeros(c_state_shape, dtype=torch.float32)
    return cls(
        c_state,
        batch_size=export_config.batch_size,
        **kwargs,
    )
71+
72+
73+
@cache_base_lib.register_cache_implementation
class LiteRTLFM2Cache(cache_lib.LiteRTLMCache):
  """Optimized Cache class for LFM2 integration."""

  @classmethod
  def create_from_config(
      cls,
      model_config,
      export_config: exportable_module_config.ExportableModuleConfig,
      **kwargs,
  ) -> "LiteRTLFM2Cache":
    """Creates a KV cache from the model config.

    One layer is built per hidden layer: a conv-state layer for
    entries tagged "conv" in `model_config.layer_types`, a regular
    key/value layer otherwise.
    """
    layers = []
    for idx in range(model_config.num_hidden_layers):
      # Pick the layer class based on the layer's declared type.
      layer_cls = (
          LiteRTLFM2CacheLayer
          if model_config.layer_types[idx] == "conv"
          else cache_lib.LiteRTLMCacheLayer
      )
      layers.append(
          layer_cls.create_from_config(model_config, idx, export_config)
      )
    return cls(layers)
105+
106+
107+
def _flatten_kvc_t(
    kvc: LiteRTLFM2Cache,
) -> Tuple[
    List[torch.Tensor], Tuple[List[str], Tuple[int, int, int, int, List[bool]]]
]:
  """Flattens the cache into a list of tensors plus a pytree context."""
  first = kvc.layers[0]
  assert isinstance(first, cache_base_lib.LiteRTLMCacheLayerMixin)
  batch = first.get_batch_size()
  k_idx = first.get_k_ts_idx()
  v_idx = first.get_v_ts_idx()
  tensors: List[torch.Tensor] = []
  names: List[str] = []
  conv_flags: List[bool] = []
  for idx, layer in enumerate(kvc.layers):
    if isinstance(layer, LiteRTLFM2CacheLayer):
      # A conv layer contributes a single conv-state tensor.
      conv_flags.append(True)
      tensors.append(layer.conv_state)
      names.append(f"c_{idx}")
    else:
      # An attention layer contributes separate key and value tensors.
      conv_flags.append(False)
      tensors.extend((layer.keys, layer.values))
      names.extend((f"k_{idx}", f"v_{idx}"))
  context = (
      names,
      (batch, len(kvc.layers), k_idx, v_idx, conv_flags),
  )
  return tensors, context
137+
138+
139+
def _unflatten_kvc_t(
    values: List[torch.Tensor],
    context: Tuple[List[str], Tuple[int, int, int, int, List[bool]]],
) -> LiteRTLFM2Cache:
  """Unflattens the cache from a list of tensors."""
  flat_names, meta = context
  batch_size, num_layers, k_ts_idx, v_ts_idx, is_conv = meta
  # Names are unique by construction (each embeds its layer index), so a
  # dict lookup is equivalent to list.index() here.
  tensor_by_name = dict(zip(flat_names, values))
  layers = []
  for i in range(num_layers):
    if is_conv[i]:
      layers.append(
          LiteRTLFM2CacheLayer(
              conv_state=tensor_by_name[f"c_{i}"],
              batch_size=batch_size,
          )
      )
    else:
      layers.append(
          cache_lib.LiteRTLMCacheLayer(
              key_cache=tensor_by_name[f"k_{i}"],
              value_cache=tensor_by_name[f"v_{i}"],
              batch_size=batch_size,
              k_ts_idx=k_ts_idx,
              v_ts_idx=v_ts_idx,
          )
      )
  return LiteRTLFM2Cache(layers)
170+
171+
172+
def _flatten_kvc_t_with_keys(
    kvc: LiteRTLFM2Cache,
):
  """Flattens the cache, pairing each tensor with a pytree MappingKey."""
  tensors, context = _flatten_kvc_t(kvc)
  names = context[0]
  keyed = list(zip(map(pytree.MappingKey, names), tensors))
  return keyed, names
179+
180+
181+
# Register LiteRTLFM2Cache with torch's pytree machinery so cache
# instances can be flattened/unflattened when crossing the export
# boundary (e.g. as torch.export inputs/outputs).
pytree.register_pytree_node(
    LiteRTLFM2Cache,
    _flatten_kvc_t,
    _unflatten_kvc_t,
    flatten_with_keys_fn=_flatten_kvc_t_with_keys,
    # NOTE(review): empty serialized_type_name — confirm pytree
    # serialization of this type is never needed.
    serialized_type_name="",
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright 2026 The LiteRT Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Patch for LFM2."""
16+
17+
import contextlib
18+
from litert_torch.generative.export_hf.model_ext import patches as patches_lib
19+
from litert_torch.generative.export_hf.model_ext.lfm2 import short_conv as short_conv_lib
20+
from transformers.models.lfm2 import modeling_lfm2
21+
22+
23+
@patches_lib.register_patch(["lfm2"])
@contextlib.contextmanager
def lfm2_litert_patch():
  """Context manager that swaps in the LiteRT-friendly Lfm2ShortConv.

  While active, `modeling_lfm2.Lfm2ShortConv` is replaced with the
  export-friendly implementation from `short_conv_lib`; the original
  class is restored on exit, even if the body raises.
  """
  print("LFM2 patch applied.")
  # Keep a reference to the original class so it can be restored.
  original_short_conv = modeling_lfm2.Lfm2ShortConv
  modeling_lfm2.Lfm2ShortConv = short_conv_lib.Lfm2ShortConv

  try:
    yield
  finally:
    modeling_lfm2.Lfm2ShortConv = original_short_conv
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2026 The LiteRT Torch Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
"""Short convolutions for LFM2."""
16+
17+
from typing import Optional
18+
import torch
19+
from transformers.models.lfm2 import modeling_lfm2
20+
21+
22+
class Lfm2ShortConv(modeling_lfm2.Lfm2ShortConv):
  """Short convolutions for LFM2, suitable for LiteRT inference."""

  def __init__(
      self,
      config: modeling_lfm2.Lfm2Config,
      layer_idx: int,
  ):
    super().__init__(config, layer_idx)
    # Replace the parent's conv with an unpadded depthwise Conv1d; the
    # required left context comes from the cached conv state instead.
    self.conv = torch.nn.Conv1d(
        in_channels=config.hidden_size,
        out_channels=config.hidden_size,
        kernel_size=self.L_cache,
        groups=config.hidden_size,
        bias=self.bias,
        padding=0,  # Padding is done in forward as part of state management.
    )

  def forward(
      self,
      hidden_states: torch.Tensor,
      past_key_values=None,
      cache_position: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
  ):
    # NOTE(review): cache_position is accepted for interface
    # compatibility but unused; past_key_values must hold a conv_state
    # for this layer (see LiteRTLFM2CacheLayer) — confirm callers
    # always pass a cache.
    x = modeling_lfm2.apply_mask_to_padding_states(
        hidden_states, attention_mask
    )
    # in_proj produces three chunks: two gates (b, c) and the conv input.
    b, c, x_proj = self.in_proj(x).chunk(3, dim=-1)
    conv_input = b * x_proj
    # Conv1d expects channels-first; assumes conv_input is
    # (batch, seq, hidden) — TODO confirm.
    conv_input_t = conv_input.transpose(1, 2)
    state = past_key_values.layers[self.layer_idx].conv_state
    # Prepend the cached left context so the unpadded conv sees the
    # full receptive field.
    padded_input = torch.cat([state, conv_input_t], dim=-1)
    # Snapshot the last L_cache - 1 positions BEFORE running the conv;
    # this becomes the left context for the next step.
    next_state = padded_input[:, :, -(self.L_cache - 1) :]
    conv_out = self.conv(padded_input)
    conv_out = conv_out.transpose(1, 2)
    y = c * conv_out
    y = self.out_proj(y)
    # Write the updated state back only after all reads are done.
    past_key_values.layers[self.layer_idx].conv_state = next_state
    return y

0 commit comments

Comments
 (0)