Commit 1c56a0d

[Fix] Watsonx - Apply correct prompt templates for openai/gpt-oss model family (#15341)
* fix: apply_prompt_template
* Revert "fix: apply_prompt_template"
  This reverts commit 3e0e40b.
* add apply_prompt_template for WatsonX
* feat: add apply_prompt_template
* test_watsonx_gpt_oss_prompt_transformation
* Revert "add apply_prompt_template for WatsonX"
  This reverts commit 3e80903.
* add apply_prompt_template for WatsonX
* fix apply_prompt_template
* fix: add hf template handler
* fix hf_chat_template
* fix _get_tokenizer_config
* fix hf_chat_template
* add WatsonXModelPattern
* fix aapply_prompt_template
1 parent 9d84a7c commit 1c56a0d
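
For context, a minimal sketch of how the new watsonx routing in prompt_factory is exercised (assuming a litellm build that includes this commit; the openai/gpt-oss model id below is illustrative only, not taken from the diff):

from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

# With custom_llm_provider="watsonx", prompt_factory now delegates to
# IBMWatsonXChatConfig.apply_prompt_template(model=..., messages=...)
# instead of the old hard-coded granite / mistral / llama-3 branches.
prompt = prompt_factory(
    model="openai/gpt-oss-120b",  # illustrative gpt-oss model id
    messages=messages,
    custom_llm_provider="watsonx",
)
print(prompt)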

File tree

8 files changed: +692 −128 lines

litellm/litellm_core_utils/prompt_templates/factory.py

Lines changed: 201 additions & 86 deletions
@@ -2,7 +2,6 @@
 import json
 import mimetypes
 import re
-from litellm._uuid import uuid
 import xml.etree.ElementTree as ET
 from enum import Enum
 from typing import Any, List, Optional, Tuple, cast, overload
@@ -13,6 +12,7 @@
 import litellm.types
 import litellm.types.llms
 from litellm import verbose_logger
+from litellm._uuid import uuid
 from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
 from litellm.types.files import get_file_extension_from_mime_type
 from litellm.types.llms.anthropic import *
@@ -364,62 +364,20 @@ def phind_codellama_pt(messages):
     return prompt


-def hf_chat_template(  # noqa: PLR0915
-    model: str, messages: list, chat_template: Optional[Any] = None
-):
-    # Define Jinja2 environment
-    env = ImmutableSandboxedEnvironment()
-
-    def raise_exception(message):
-        raise Exception(f"Error message - {message}")
-
-    # Create a template object from the template text
-    env.globals["raise_exception"] = raise_exception
-
-    ## get the tokenizer config from huggingface
-    bos_token = ""
-    eos_token = ""
-    if chat_template is None:
-
-        def _get_tokenizer_config(hf_model_name):
-            try:
-                url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-                # Make a GET request to fetch the JSON data
-                client = HTTPHandler(concurrent_limit=1)
-
-                response = client.get(url)
-            except Exception as e:
-                raise e
-            if response.status_code == 200:
-                # Parse the JSON data
-                tokenizer_config = json.loads(response.content)
-                return {"status": "success", "tokenizer": tokenizer_config}
-            else:
-                return {"status": "failure"}
-
-        if model in litellm.known_tokenizer_config:
-            tokenizer_config = litellm.known_tokenizer_config[model]
-        else:
-            tokenizer_config = _get_tokenizer_config(model)
-            litellm.known_tokenizer_config.update({model: tokenizer_config})
-
-        if (
-            tokenizer_config["status"] == "failure"
-            or "chat_template" not in tokenizer_config["tokenizer"]
-        ):
-            raise Exception("No chat template found")
-        ## read the bos token, eos token and chat template from the json
-        tokenizer_config = tokenizer_config["tokenizer"]  # type: ignore
-
-        bos_token = tokenizer_config["bos_token"]  # type: ignore
-        if bos_token is not None and not isinstance(bos_token, str):
-            if isinstance(bos_token, dict):
-                bos_token = bos_token.get("content", None)
-        eos_token = tokenizer_config["eos_token"]  # type: ignore
-        if eos_token is not None and not isinstance(eos_token, str):
-            if isinstance(eos_token, dict):
-                eos_token = eos_token.get("content", None)
-        chat_template = tokenizer_config["chat_template"]  # type: ignore
+def _render_chat_template(env, chat_template: str, bos_token: str, eos_token: str, messages: list) -> str:
+    """
+    Shared template rendering logic for both sync and async hf_chat_template
+
+    Args:
+        env: Jinja2 environment
+        chat_template: Chat template string
+        bos_token: Beginning of sequence token
+        eos_token: End of sequence token
+        messages: Messages to render
+
+    Returns:
+        Rendered template string
+    """
     try:
         template = env.from_string(chat_template)  # type: ignore
     except Exception as e:
@@ -434,7 +392,6 @@ def _is_system_in_template():
                 bos_token="<bos>",
             )
             return True
-
         # This will be raised if Jinja attempts to render the system message and it can't
         except Exception:
             return False
@@ -468,7 +425,7 @@ def _is_system_in_template():
                 )
             except Exception as e:
                 if "Conversation roles must alternate user/assistant" in str(e):
-                    # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+                    # reformat messages to ensure user/assistant are alternating
                     new_messages = []
                     for i in range(len(reformatted_messages) - 1):
                         new_messages.append(reformatted_messages[i])
@@ -494,6 +451,188 @@ def _is_system_in_template():
         )  # don't use verbose_logger.exception, if exception is raised


+async def _afetch_and_extract_template(
+    model: str, chat_template: Optional[Any], get_config_fn, get_template_fn
+) -> Tuple[str, str, str]:
+    """
+    Async version: Fetch template and tokens from HuggingFace.
+
+    Returns: (chat_template, bos_token, eos_token)
+    """
+    from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+        _extract_token_value,
+    )
+
+    bos_token = ""
+    eos_token = ""
+
+    if chat_template is None:
+        # Fetch or retrieve cached tokenizer config
+        if model in litellm.known_tokenizer_config:
+            tokenizer_config = litellm.known_tokenizer_config[model]
+        else:
+            tokenizer_config = await get_config_fn(hf_model_name=model)
+            litellm.known_tokenizer_config.update({model: tokenizer_config})
+
+        # Try to get chat template from tokenizer_config.json first
+        if (
+            tokenizer_config.get("status") == "success"
+            and "tokenizer" in tokenizer_config
+            and isinstance(tokenizer_config["tokenizer"], dict)
+            and "chat_template" in tokenizer_config["tokenizer"]
+        ):
+            tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+            bos_token = _extract_token_value(
+                token_value=tokenizer_data.get("bos_token")
+            )
+            eos_token = _extract_token_value(
+                token_value=tokenizer_data.get("eos_token")
+            )
+            chat_template = tokenizer_data["chat_template"]
+        else:
+            # Fallback: Try to fetch chat template from separate .jinja file
+            template_result = await get_template_fn(hf_model_name=model)
+            if template_result.get("status") == "success":
+                chat_template = template_result["chat_template"]
+                # Still try to get tokens from tokenizer_config if available
+                if (
+                    tokenizer_config.get("status") == "success"
+                    and "tokenizer" in tokenizer_config
+                    and isinstance(tokenizer_config["tokenizer"], dict)
+                ):
+                    tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+                    bos_token = _extract_token_value(
+                        token_value=tokenizer_data.get("bos_token")
+                    )
+                    eos_token = _extract_token_value(
+                        token_value=tokenizer_data.get("eos_token")
+                    )
+            else:
+                raise Exception("No chat template found")
+
+    return chat_template, bos_token, eos_token  # type: ignore
+
+
+def _fetch_and_extract_template(
+    model: str, chat_template: Optional[Any], get_config_fn, get_template_fn
+) -> Tuple[str, str, str]:
+    """
+    Sync version: Fetch template and tokens from HuggingFace.
+
+    Returns: (chat_template, bos_token, eos_token)
+    """
+    from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+        _extract_token_value,
+    )
+
+    bos_token = ""
+    eos_token = ""
+
+    if chat_template is None:
+        # Fetch or retrieve cached tokenizer config
+        if model in litellm.known_tokenizer_config:
+            tokenizer_config = litellm.known_tokenizer_config[model]
+        else:
+            tokenizer_config = get_config_fn(hf_model_name=model)
+            litellm.known_tokenizer_config.update({model: tokenizer_config})
+
+        # Try to get chat template from tokenizer_config.json first
+        if (
+            tokenizer_config.get("status") == "success"
+            and "tokenizer" in tokenizer_config
+            and isinstance(tokenizer_config["tokenizer"], dict)
+            and "chat_template" in tokenizer_config["tokenizer"]
+        ):
+            tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+            bos_token = _extract_token_value(
+                token_value=tokenizer_data.get("bos_token")
+            )
+            eos_token = _extract_token_value(
+                token_value=tokenizer_data.get("eos_token")
+            )
+            chat_template = tokenizer_data["chat_template"]
+        else:
+            # Fallback: Try to fetch chat template from separate .jinja file
+            template_result = get_template_fn(hf_model_name=model)
+            if template_result.get("status") == "success":
+                chat_template = template_result["chat_template"]
+                # Still try to get tokens from tokenizer_config if available
+                if (
+                    tokenizer_config.get("status") == "success"
+                    and "tokenizer" in tokenizer_config
+                    and isinstance(tokenizer_config["tokenizer"], dict)
+                ):
+                    tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+                    bos_token = _extract_token_value(
+                        token_value=tokenizer_data.get("bos_token")
+                    )
+                    eos_token = _extract_token_value(
+                        token_value=tokenizer_data.get("eos_token")
+                    )
+            else:
+                raise Exception("No chat template found")
+
+    return chat_template, bos_token, eos_token  # type: ignore
+
+
+async def ahf_chat_template(
+    model: str, messages: list, chat_template: Optional[Any] = None
+):
+    """HuggingFace chat template (async version)"""
+    from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+        _aget_chat_template_file,
+        _aget_tokenizer_config,
+        strftime_now,
+    )
+
+    env = ImmutableSandboxedEnvironment()
+    env.globals["raise_exception"] = lambda msg: Exception(f"Error message - {msg}")
+    env.globals["strftime_now"] = strftime_now
+
+    template, bos_token, eos_token = await _afetch_and_extract_template(
+        model=model,
+        chat_template=chat_template,
+        get_config_fn=_aget_tokenizer_config,
+        get_template_fn=_aget_chat_template_file,
+    )
+    return _render_chat_template(
+        env=env,
+        chat_template=template,
+        bos_token=bos_token,
+        eos_token=eos_token,
+        messages=messages,
+    )
+
+
+def hf_chat_template(
+    model: str, messages: list, chat_template: Optional[Any] = None
+):
+    """HuggingFace chat template (sync version)"""
+    from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+        _get_chat_template_file,
+        _get_tokenizer_config,
+        strftime_now,
+    )
+
+    env = ImmutableSandboxedEnvironment()
+    env.globals["raise_exception"] = lambda msg: Exception(f"Error message - {msg}")
+    env.globals["strftime_now"] = strftime_now
+
+    template, bos_token, eos_token = _fetch_and_extract_template(
+        model=model,
+        chat_template=chat_template,
+        get_config_fn=_get_tokenizer_config,
+        get_template_fn=_get_chat_template_file,
+    )
+    return _render_chat_template(
+        env=env,
+        chat_template=template,
+        bos_token=bos_token,
+        eos_token=eos_token,
+        messages=messages,
+    )
+
+
 def deepseek_r1_pt(messages):
     return hf_chat_template(
         model="deepseek-r1/deepseek-r1-7b-instruct", messages=messages
@@ -4031,33 +4170,9 @@ def prompt_factory(
     elif custom_llm_provider == "azure_text":
         return azure_text_pt(messages=messages)
     elif custom_llm_provider == "watsonx":
-        if "granite" in model and "chat" in model:
-            # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
-            return ibm_granite_pt(messages=messages)
-        elif "ibm-mistral" in model and "instruct" in model:
-            # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
-            return mistral_instruct_pt(messages=messages)
-        elif "meta-llama/llama-3" in model and "instruct" in model:
-            # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-            return custom_prompt(
-                role_dict={
-                    "system": {
-                        "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
-                        "post_message": "<|eot_id|>",
-                    },
-                    "user": {
-                        "pre_message": "<|start_header_id|>user<|end_header_id|>\n",
-                        "post_message": "<|eot_id|>",
-                    },
-                    "assistant": {
-                        "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
-                        "post_message": "<|eot_id|>",
-                    },
-                },
-                messages=messages,
-                initial_prompt_value="<|begin_of_text|>",
-                final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
-            )
+        from litellm.llms.watsonx.chat.transformation import IBMWatsonXChatConfig
+        return IBMWatsonXChatConfig.apply_prompt_template(model=model, messages=messages)
+
     try:
         if "meta-llama/llama-2" in model and "chat" in model:
             return llama_2_chat_pt(messages=messages)
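
For reference, a minimal sketch of the refactored sync entry point (it assumes huggingface.co is reachable so tokenizer_config.json, or the .jinja fallback, can be fetched; ahf_chat_template is the async counterpart with the same call shape):

from litellm.litellm_core_utils.prompt_templates.factory import hf_chat_template

# Fetches the chat template (tokenizer_config.json first, separate .jinja file
# as a fallback), extracts bos/eos tokens, then renders via _render_chat_template.
prompt = hf_chat_template(
    model="deepseek-r1/deepseek-r1-7b-instruct",  # same model id used by deepseek_r1_pt above
    messages=[{"role": "user", "content": "Say hi."}],
)
print(prompt)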
