import json
import mimetypes
import re
- from litellm._uuid import uuid
import xml.etree.ElementTree as ET
from enum import Enum
from typing import Any, List, Optional, Tuple, cast, overload
import litellm.types
import litellm.types.llms
from litellm import verbose_logger
+ from litellm._uuid import uuid
from litellm.llms.custom_httpx.http_handler import HTTPHandler, get_async_httpx_client
from litellm.types.files import get_file_extension_from_mime_type
from litellm.types.llms.anthropic import *
@@ -364,62 +364,20 @@ def phind_codellama_pt(messages):
    return prompt


- def hf_chat_template(  # noqa: PLR0915
-     model: str, messages: list, chat_template: Optional[Any] = None
- ):
-     # Define Jinja2 environment
-     env = ImmutableSandboxedEnvironment()
-
-     def raise_exception(message):
-         raise Exception(f"Error message - {message}")
-
-     # Create a template object from the template text
-     env.globals["raise_exception"] = raise_exception
-
-     ## get the tokenizer config from huggingface
-     bos_token = ""
-     eos_token = ""
-     if chat_template is None:
-
-         def _get_tokenizer_config(hf_model_name):
-             try:
-                 url = f"https://huggingface.co/{hf_model_name}/raw/main/tokenizer_config.json"
-                 # Make a GET request to fetch the JSON data
-                 client = HTTPHandler(concurrent_limit=1)
-
-                 response = client.get(url)
-             except Exception as e:
-                 raise e
-             if response.status_code == 200:
-                 # Parse the JSON data
-                 tokenizer_config = json.loads(response.content)
-                 return {"status": "success", "tokenizer": tokenizer_config}
-             else:
-                 return {"status": "failure"}
-
-         if model in litellm.known_tokenizer_config:
-             tokenizer_config = litellm.known_tokenizer_config[model]
-         else:
-             tokenizer_config = _get_tokenizer_config(model)
-             litellm.known_tokenizer_config.update({model: tokenizer_config})
-
-         if (
-             tokenizer_config["status"] == "failure"
-             or "chat_template" not in tokenizer_config["tokenizer"]
-         ):
-             raise Exception("No chat template found")
-         ## read the bos token, eos token and chat template from the json
-         tokenizer_config = tokenizer_config["tokenizer"]  # type: ignore
-
-         bos_token = tokenizer_config["bos_token"]  # type: ignore
-         if bos_token is not None and not isinstance(bos_token, str):
-             if isinstance(bos_token, dict):
-                 bos_token = bos_token.get("content", None)
-         eos_token = tokenizer_config["eos_token"]  # type: ignore
-         if eos_token is not None and not isinstance(eos_token, str):
-             if isinstance(eos_token, dict):
-                 eos_token = eos_token.get("content", None)
-         chat_template = tokenizer_config["chat_template"]  # type: ignore
+ def _render_chat_template(env, chat_template: str, bos_token: str, eos_token: str, messages: list) -> str:
+     """
+     Shared template rendering logic for both sync and async hf_chat_template
+
+     Args:
+         env: Jinja2 environment
+         chat_template: Chat template string
+         bos_token: Beginning of sequence token
+         eos_token: End of sequence token
+         messages: Messages to render
+
+     Returns:
+         Rendered template string
+     """
    try:
        template = env.from_string(chat_template)  # type: ignore
    except Exception as e:
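For context on what _render_chat_template ultimately does with a fetched template, here is a minimal stand-alone sketch (not part of this diff; the template string is invented purely for illustration) of rendering an HF-style chat template through Jinja2's sandboxed environment:

    from jinja2.sandbox import ImmutableSandboxedEnvironment

    env = ImmutableSandboxedEnvironment()
    # A toy chat template in the same style as HuggingFace tokenizer_config.json templates.
    template = env.from_string(
        "{{ bos_token }}{% for m in messages %}"
        "<|{{ m['role'] }}|>{{ m['content'] }}{{ eos_token }}{% endfor %}"
    )
    print(template.render(
        bos_token="<s>",
        eos_token="</s>",
        messages=[{"role": "user", "content": "Hi"}],
    ))
    # -> <s><|user|>Hi</s>
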
@@ -434,7 +392,6 @@ def _is_system_in_template():
                bos_token="<bos>",
            )
            return True
-
        # This will be raised if Jinja attempts to render the system message and it can't
        except Exception:
            return False
@@ -468,7 +425,7 @@ def _is_system_in_template():
            )
    except Exception as e:
        if "Conversation roles must alternate user/assistant" in str(e):
-             # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
+             # reformat messages to ensure user/assistant are alternating
            new_messages = []
            for i in range(len(reformatted_messages) - 1):
                new_messages.append(reformatted_messages[i])
@@ -494,6 +451,188 @@ def _is_system_in_template():
        )  # don't use verbose_logger.exception, if exception is raised


+ async def _afetch_and_extract_template(
+     model: str, chat_template: Optional[Any], get_config_fn, get_template_fn
+ ) -> Tuple[str, str, str]:
+     """
+     Async version: Fetch template and tokens from HuggingFace.
+
+     Returns: (chat_template, bos_token, eos_token)
+     """
+     from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+         _extract_token_value,
+     )
+
+     bos_token = ""
+     eos_token = ""
+
+     if chat_template is None:
+         # Fetch or retrieve cached tokenizer config
+         if model in litellm.known_tokenizer_config:
+             tokenizer_config = litellm.known_tokenizer_config[model]
+         else:
+             tokenizer_config = await get_config_fn(hf_model_name=model)
+             litellm.known_tokenizer_config.update({model: tokenizer_config})
+
+         # Try to get chat template from tokenizer_config.json first
+         if (
+             tokenizer_config.get("status") == "success"
+             and "tokenizer" in tokenizer_config
+             and isinstance(tokenizer_config["tokenizer"], dict)
+             and "chat_template" in tokenizer_config["tokenizer"]
+         ):
+             tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+             bos_token = _extract_token_value(
+                 token_value=tokenizer_data.get("bos_token")
+             )
+             eos_token = _extract_token_value(
+                 token_value=tokenizer_data.get("eos_token")
+             )
+             chat_template = tokenizer_data["chat_template"]
+         else:
+             # Fallback: Try to fetch chat template from separate .jinja file
+             template_result = await get_template_fn(hf_model_name=model)
+             if template_result.get("status") == "success":
+                 chat_template = template_result["chat_template"]
+                 # Still try to get tokens from tokenizer_config if available
+                 if (
+                     tokenizer_config.get("status") == "success"
+                     and "tokenizer" in tokenizer_config
+                     and isinstance(tokenizer_config["tokenizer"], dict)
+                 ):
+                     tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+                     bos_token = _extract_token_value(
+                         token_value=tokenizer_data.get("bos_token")
+                     )
+                     eos_token = _extract_token_value(
+                         token_value=tokenizer_data.get("eos_token")
+                     )
+             else:
+                 raise Exception("No chat template found")
+
+     return chat_template, bos_token, eos_token  # type: ignore
+
+
+ def _fetch_and_extract_template(
+     model: str, chat_template: Optional[Any], get_config_fn, get_template_fn
+ ) -> Tuple[str, str, str]:
+     """
+     Sync version: Fetch template and tokens from HuggingFace.
+
+     Returns: (chat_template, bos_token, eos_token)
+     """
+     from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+         _extract_token_value,
+     )
+
+     bos_token = ""
+     eos_token = ""
+
+     if chat_template is None:
+         # Fetch or retrieve cached tokenizer config
+         if model in litellm.known_tokenizer_config:
+             tokenizer_config = litellm.known_tokenizer_config[model]
+         else:
+             tokenizer_config = get_config_fn(hf_model_name=model)
+             litellm.known_tokenizer_config.update({model: tokenizer_config})
+
+         # Try to get chat template from tokenizer_config.json first
+         if (
+             tokenizer_config.get("status") == "success"
+             and "tokenizer" in tokenizer_config
+             and isinstance(tokenizer_config["tokenizer"], dict)
+             and "chat_template" in tokenizer_config["tokenizer"]
+         ):
+             tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+             bos_token = _extract_token_value(
+                 token_value=tokenizer_data.get("bos_token")
+             )
+             eos_token = _extract_token_value(
+                 token_value=tokenizer_data.get("eos_token")
+             )
+             chat_template = tokenizer_data["chat_template"]
+         else:
+             # Fallback: Try to fetch chat template from separate .jinja file
+             template_result = get_template_fn(hf_model_name=model)
+             if template_result.get("status") == "success":
+                 chat_template = template_result["chat_template"]
+                 # Still try to get tokens from tokenizer_config if available
+                 if (
+                     tokenizer_config.get("status") == "success"
+                     and "tokenizer" in tokenizer_config
+                     and isinstance(tokenizer_config["tokenizer"], dict)
+                 ):
+                     tokenizer_data: dict = tokenizer_config["tokenizer"]  # type: ignore
+                     bos_token = _extract_token_value(
+                         token_value=tokenizer_data.get("bos_token")
+                     )
+                     eos_token = _extract_token_value(
+                         token_value=tokenizer_data.get("eos_token")
+                     )
+             else:
+                 raise Exception("No chat template found")
+
+     return chat_template, bos_token, eos_token  # type: ignore
+
+
+ async def ahf_chat_template(
+     model: str, messages: list, chat_template: Optional[Any] = None
+ ):
+     """HuggingFace chat template (async version)"""
+     from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+         _aget_chat_template_file,
+         _aget_tokenizer_config,
+         strftime_now,
+     )
+
+     env = ImmutableSandboxedEnvironment()
+     env.globals["raise_exception"] = lambda msg: Exception(f"Error message - {msg}")
+     env.globals["strftime_now"] = strftime_now
+
+     template, bos_token, eos_token = await _afetch_and_extract_template(
+         model=model,
+         chat_template=chat_template,
+         get_config_fn=_aget_tokenizer_config,
+         get_template_fn=_aget_chat_template_file,
+     )
+     return _render_chat_template(
+         env=env,
+         chat_template=template,
+         bos_token=bos_token,
+         eos_token=eos_token,
+         messages=messages,
+     )
+
+
+ def hf_chat_template(
+     model: str, messages: list, chat_template: Optional[Any] = None
+ ):
+     """HuggingFace chat template (sync version)"""
+     from litellm.litellm_core_utils.prompt_templates.huggingface_template_handler import (
+         _get_chat_template_file,
+         _get_tokenizer_config,
+         strftime_now,
+     )
+
+     env = ImmutableSandboxedEnvironment()
+     env.globals["raise_exception"] = lambda msg: Exception(f"Error message - {msg}")
+     env.globals["strftime_now"] = strftime_now
+
+     template, bos_token, eos_token = _fetch_and_extract_template(
+         model=model,
+         chat_template=chat_template,
+         get_config_fn=_get_tokenizer_config,
+         get_template_fn=_get_chat_template_file,
+     )
+     return _render_chat_template(
+         env=env,
+         chat_template=template,
+         bos_token=bos_token,
+         eos_token=eos_token,
+         messages=messages,
+     )
+
+
def deepseek_r1_pt(messages):
    return hf_chat_template(
        model="deepseek-r1/deepseek-r1-7b-instruct", messages=messages
@@ -4031,33 +4170,9 @@ def prompt_factory(
    elif custom_llm_provider == "azure_text":
        return azure_text_pt(messages=messages)
    elif custom_llm_provider == "watsonx":
-         if "granite" in model and "chat" in model:
-             # granite-13b-chat-v1 and granite-13b-chat-v2 use a specific prompt template
-             return ibm_granite_pt(messages=messages)
-         elif "ibm-mistral" in model and "instruct" in model:
-             # models like ibm-mistral/mixtral-8x7b-instruct-v01-q use the mistral instruct prompt template
-             return mistral_instruct_pt(messages=messages)
-         elif "meta-llama/llama-3" in model and "instruct" in model:
-             # https://llama.meta.com/docs/model-cards-and-prompt-formats/meta-llama-3/
-             return custom_prompt(
-                 role_dict={
-                     "system": {
-                         "pre_message": "<|start_header_id|>system<|end_header_id|>\n",
-                         "post_message": "<|eot_id|>",
-                     },
-                     "user": {
-                         "pre_message": "<|start_header_id|>user<|end_header_id|>\n",
-                         "post_message": "<|eot_id|>",
-                     },
-                     "assistant": {
-                         "pre_message": "<|start_header_id|>assistant<|end_header_id|>\n",
-                         "post_message": "<|eot_id|>",
-                     },
-                 },
-                 messages=messages,
-                 initial_prompt_value="<|begin_of_text|>",
-                 final_prompt_value="<|start_header_id|>assistant<|end_header_id|>\n",
-             )
+         from litellm.llms.watsonx.chat.transformation import IBMWatsonXChatConfig
+         return IBMWatsonXChatConfig.apply_prompt_template(model=model, messages=messages)
+
    try:
        if "meta-llama/llama-2" in model and "chat" in model:
            return llama_2_chat_pt(messages=messages)
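With this last hunk, the watsonx branch defers all model-specific prompt shaping (Granite, Mistral-instruct, Llama-3) to IBMWatsonXChatConfig.apply_prompt_template instead of hard-coding templates inside prompt_factory. A hedged call sketch — the prompt_factory import path and keyword signature are assumptions based on how litellm exposes it elsewhere, and the model id is illustrative:

    # Import path assumed; adjust to wherever prompt_factory lives in your litellm version.
    from litellm.litellm_core_utils.prompt_templates.factory import prompt_factory

    # After this change, watsonx models route through IBMWatsonXChatConfig.apply_prompt_template.
    prompt = prompt_factory(
        model="ibm/granite-13b-chat-v2",
        messages=[{"role": "user", "content": "Hello"}],
        custom_llm_provider="watsonx",
    )
    print(prompt)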