Skip to content

Commit 4c84cc3

Browse files
sirakiin authored and copybara-github committed
No public description
PiperOrigin-RevId: 865674419
1 parent 7ec9a19 commit 4c84cc3

File tree

4 files changed

+45
-5
lines changed

4 files changed

+45
-5
lines changed

litert_torch/generative/export_hf/core/export_lib.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def load_model(
8080
model_path: str,
8181
trust_remote_code: bool = False,
8282
auto_model_override: str | None = None,
83+
task: str = 'text_generation',
8384
):
8485
"""Loads model from checkpoint."""
8586

@@ -90,7 +91,12 @@ def load_model(
9091
)
9192
config._attn_implementation = 'lrt_transposed_attention' # pylint: disable=protected-access
9293

93-
auto_model_cls = transformers.AutoModelForCausalLM
94+
if task == 'text_generation':
95+
auto_model_cls = transformers.AutoModelForCausalLM
96+
elif task == 'image_text_to_text':
97+
auto_model_cls = transformers.AutoModelForImageTextToText
98+
else:
99+
raise ValueError(f'Unsupported task: {task}')
94100
if auto_model_override is not None:
95101
auto_model_cls = transformers.__dict__[auto_model_override]
96102

@@ -101,14 +107,16 @@ def load_model(
101107
trust_remote_code=trust_remote_code,
102108
)
103109

104-
model.generation_config.cache_implementation = 'static'
105-
model.generation_config.do_sample = False
110+
if task == 'text_generation':
111+
model.generation_config.cache_implementation = 'static'
112+
model.generation_config.do_sample = False
106113

107114
text_model_config = config
108115
if hasattr(config, 'text_config'):
109116
text_model_config = config.text_config
110117

111-
verify_model_compatibility(model, config, text_model_config)
118+
if task == 'text_generation':
119+
verify_model_compatibility(model, config, text_model_config)
112120

113121
# TODO(weiyiw): Refactor into a separate function.
114122
tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)

litert_torch/generative/export_hf/core/litert_lm_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,8 @@ def build_llm_metadata(
118118
if isinstance(gen_config.eos_token_id, int):
119119
stop_tokens.add(gen_config.eos_token_id)
120120
elif isinstance(gen_config.eos_token_id, list):
121-
stop_tokens.update(gen_config.eos_token_id)
121+
for token_id in gen_config.eos_token_id:
122+
stop_tokens.add(token_id)
122123
elif hasattr(tokenizer, 'eos_token') and tokenizer.eos_token:
123124
stop_tokens.add(tokenizer.eos_token)
124125
for stop_token in stop_tokens:

litert_torch/generative/export_hf/core/patches.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,3 +60,32 @@ def decorator(cls):
6060
transformers.integrations.use_kernel_forward_from_hub = (
6161
_use_kernel_forward_from_hub
6262
)
63+
64+
65+
# TODO(weiyiw): Find a better way to patch Gemma3RMSNorm.
66+
class Gemma3RMSNorm(torch.nn.Module):
67+
"""RMSNorm Layer."""
68+
69+
def __init__(self, dim: int, eps: float = 1e-6):
70+
"""RMSNorm Layer."""
71+
super().__init__()
72+
self.weight = torch.nn.Parameter(torch.ones(dim))
73+
self.variance_epsilon = eps
74+
self.hidden_size = dim
75+
76+
def forward(self, hidden_states):
77+
return normalization.rms_norm_with_hlfb(
78+
hidden_states,
79+
self.weight + 1.0,
80+
self.variance_epsilon,
81+
torch.ones((self.hidden_size,), dtype=torch.float32),
82+
)
83+
84+
def extra_repr(self):
85+
return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
86+
87+
88+
from transformers.models.gemma3 import modeling_gemma3
89+
90+
original_gemma3_rms_norm = modeling_gemma3.Gemma3RMSNorm
91+
modeling_gemma3.Gemma3RMSNorm = Gemma3RMSNorm

litert_torch/generative/export_hf/export.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ def export(
3939
# target_accelerator: str | None = None,
4040
trust_remote_code: bool = False,
4141
use_jinja_template: bool = False,
42+
task: str = 'text_generation',
4243
):
4344
"""Exports HuggingFace Transformers model to tflite."""
4445
# TODO(weiyiw): Use tmp dir for work_dir.
@@ -48,6 +49,7 @@ def export(
4849
model,
4950
trust_remote_code=trust_remote_code,
5051
auto_model_override=auto_model_override,
52+
task=task,
5153
)
5254
del config # Unused.
5355
if split_cache and not externalize_embedder:

0 commit comments

Comments (0)