
Commit 65c910c

sirakiin authored and copybara-github committed

No public description

PiperOrigin-RevId: 861510441
1 parent 857bb8d commit 65c910c

2 files changed: +25 -17 lines changed

litert_torch/generative/export_hf/core/export_lib.py
Lines changed: 1 addition & 1 deletion

```diff
@@ -115,7 +115,7 @@ def load_model(
 
   # TODO(weiyiw): Refactor into a separate function.
   tokenizer = transformers.AutoTokenizer.from_pretrained(model_path)
-  if not hasattr(tokenizer, 'chat_template'):
+  if not hasattr(tokenizer, 'chat_template') or not tokenizer.chat_template:
     try:
       if utils.get_model_path_type(model_path) == 'repo_id':
         template_file = huggingface_hub.hf_hub_download(
```
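The widened check matters because a tokenizer can carry a `chat_template` attribute that is `None` or an empty string, which the old `hasattr` test treated as present. A minimal sketch of the behavioral difference, using a hypothetical stand-in object rather than a real `transformers` tokenizer:

```python
# Hypothetical stand-in: tokenizers loaded from repos without a chat
# template can expose the attribute but leave it empty.
class FakeTokenizer:
  chat_template = None

tok = FakeTokenizer()

# Old check: the attribute exists, so the fallback never triggers.
needs_fallback_old = not hasattr(tok, 'chat_template')  # False

# New check: also fires when the attribute is None or ''.
needs_fallback_new = (
    not hasattr(tok, 'chat_template') or not tok.chat_template  # True
)

assert not needs_fallback_old
assert needs_fallback_new
```

With the new condition, the `hf_hub_download` fallback runs for such models instead of being silently skipped.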

litert_torch/generative/export_hf/core/litert_lm_builder.py
Lines changed: 24 additions & 16 deletions

```diff
@@ -27,7 +27,7 @@
 def parse_chat_template(tokenizer):
   """Parses chat template."""
   if tokenizer.chat_template is None:
-    return (None, None), (None, None), (None, None)
+    return None
   try:
     messages = [
         {'role': 'system', 'content': _PH},
@@ -39,14 +39,21 @@ def parse_chat_template(tokenizer):
       add_generation_prompt=False,
   )
   sys_prompt_parts = sys_prompt.split(_PH)
+  no_sys_prompt = False
+  if len(sys_prompt_parts) == 1:
+    sys_prompt_parts = [sys_prompt_parts[0], '']
+    no_sys_prompt = True
   if len(sys_prompt_parts) != 2:
     raise ValueError(
         f'System prompt {_PH} not found in chat template: {sys_prompt}'
     )
   if sys_prompt_parts[0].startswith(str(tokenizer.bos_token)):
     sys_prompt_parts[0] = sys_prompt_parts[0][len(tokenizer.bos_token) :]
 
-  messages.append({'role': 'user', 'content': _PH})
+  if no_sys_prompt:
+    messages = [{'role': 'user', 'content': _PH}]
+  else:
+    messages.append({'role': 'user', 'content': _PH})
   user_prompt = tokenizer.apply_chat_template(
       messages,
       tokenize=False,
```
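The new `no_sys_prompt` branch lets `parse_chat_template` accept templates that have no system role: if splitting the rendered prompt on the placeholder yields a single piece, the system suffix falls back to the empty string and the user turn is later rendered on its own rather than appended after a system message. A minimal sketch of just the split logic, with an illustrative placeholder value standing in for the module's `_PH` constant and `split_sys_prompt` as a hypothetical extraction:

```python
# Illustrative placeholder; the module defines its own _PH constant.
_PH = '<|PLACEHOLDER|>'

def split_sys_prompt(sys_prompt):
  """Hypothetical extraction of the new split-and-fallback logic."""
  sys_prompt_parts = sys_prompt.split(_PH)
  no_sys_prompt = False
  if len(sys_prompt_parts) == 1:
    # Placeholder absent: the template has no system slot, so fall
    # back to an empty suffix instead of raising.
    sys_prompt_parts = [sys_prompt_parts[0], '']
    no_sys_prompt = True
  if len(sys_prompt_parts) != 2:
    raise ValueError(f'System prompt {_PH} not found: {sys_prompt}')
  return sys_prompt_parts, no_sys_prompt

# Template with a system slot: placeholder splits into prefix/suffix.
print(split_sys_prompt(f'<sys>{_PH}</sys>'))  # (['<sys>', '</sys>'], False)
# Template without one: the whole render is the prefix, suffix is empty.
print(split_sys_prompt('<sys>'))              # (['<sys>', ''], True)
```

When `no_sys_prompt` is set, the message list is rebuilt with only the user turn, so the system placeholder never leaks into the user prompt.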
```diff
@@ -133,20 +140,21 @@ def build_llm_metadata(
   if gen_config.temperature:
     sampler_params.temperature = gen_config.temperature
 
-  if isinstance(chat_templates, str):
-    llm_metadata.jinja_prompt_template = chat_templates
-  else:
-    sys_prompt_parts, user_prompt_parts, model_prompt_parts = chat_templates
-    pairs = []
-    if sys_prompt_parts[0] is not None:
-      pairs.append((sys_prompt_parts, llm_metadata.prompt_templates.system))
-    if user_prompt_parts[0] is not None:
-      pairs.append((user_prompt_parts, llm_metadata.prompt_templates.user))
-    if model_prompt_parts[0] is not None:
-      pairs.append((model_prompt_parts, llm_metadata.prompt_templates.model))
-    for pts, fld in pairs:
-      fld.prefix = pts[0]
-      fld.suffix = pts[1]
+  if chat_templates is not None:
+    if isinstance(chat_templates, str):
+      llm_metadata.jinja_prompt_template = chat_templates
+    else:
+      sys_prompt_parts, user_prompt_parts, model_prompt_parts = chat_templates
+      pairs = []
+      if sys_prompt_parts[0] is not None:
+        pairs.append((sys_prompt_parts, llm_metadata.prompt_templates.system))
+      if user_prompt_parts[0] is not None:
+        pairs.append((user_prompt_parts, llm_metadata.prompt_templates.user))
+      if model_prompt_parts[0] is not None:
+        pairs.append((model_prompt_parts, llm_metadata.prompt_templates.model))
+      for pts, fld in pairs:
+        fld.prefix = pts[0]
+        fld.suffix = pts[1]
 
   llm_metadata.max_num_tokens = context_length
 
```
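Since `parse_chat_template` now returns `None` for tokenizers without a chat template, `build_llm_metadata` wraps the template handling in a `None` guard instead of unpacking a tuple of `(None, None)` pairs. A minimal sketch of the three shapes `chat_templates` can now take, using a plain dict and the hypothetical helper `apply_chat_templates` as stand-ins for the real proto fields:

```python
# Hypothetical helper mirroring the new guard; the real code writes
# into llm_metadata proto fields, not a dict.
def apply_chat_templates(llm_metadata, chat_templates):
  if chat_templates is not None:
    if isinstance(chat_templates, str):
      # A raw Jinja template string is stored verbatim.
      llm_metadata['jinja_prompt_template'] = chat_templates
    else:
      # Three (prefix, suffix) pairs: system, user, model.
      for name, parts in zip(('system', 'user', 'model'), chat_templates):
        if parts[0] is not None:
          llm_metadata[name] = {'prefix': parts[0], 'suffix': parts[1]}
  return llm_metadata

print(apply_chat_templates({}, None))              # {} -- no template at all
print(apply_chat_templates({}, '{{ messages }}'))  # Jinja string kept as-is
print(apply_chat_templates(
    {}, ((None, None), ('<u>', '</u>'), ('<m>', '</m>'))))
```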
