@@ -532,6 +532,7 @@ def format_phind(
532532    _prompt  =  _format_add_colon_single (_system_message , _messages , _sep )
533533    return  ChatFormatterResponse (prompt = _prompt )
534534
535+ 
535536@register_chat_format ("intel" ) 
536537def  format_intel (
537538    messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -588,6 +589,7 @@ def format_mistrallite(
588589    _prompt  =  _format_no_colon_single (system_message , _messages , _sep )
589590    return  ChatFormatterResponse (prompt = _prompt )
590591
592+ 
591593@register_chat_format ("chatml" ) 
592594def  format_chatml (
593595    messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -604,6 +606,7 @@ def format_chatml(
604606    _prompt  =  _format_chatml (system_message , _messages , _sep )
605607    return  ChatFormatterResponse (prompt = _prompt , stop = _sep )
606608
609+ 
607610@register_chat_format ("openchat" ) 
608611def  format_openchat (
609612    messages : List [llama_types .ChatCompletionRequestMessage ],
@@ -612,7 +615,9 @@ def format_openchat(
612615    system_template  =  "{system_message}<|end_of_turn|>" 
613616    system_message  =  _get_system_message (messages )
614617    system_message  =  system_template .format (system_message = system_message )
615-     _roles  =  dict (user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: " )
618+     _roles  =  dict (
619+         user = "GPT4 Correct User: " , assistant = "<|end_of_turn|>GPT4 Correct Assistant: " 
620+     )
616621    _sep  =  "<|end_of_turn|>" 
617622    _messages  =  _map_roles (messages , _roles )
618623    _messages .append ((_roles ["assistant" ], None ))
@@ -651,46 +656,60 @@ def functionary_chat_handler(
651656) ->  Union [llama_types .ChatCompletion , Iterator [llama_types .ChatCompletionChunk ]]:
652657    SYSTEM_MESSAGE  =  """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. The assistant calls functions with appropriate input when necessary""" 
653658
654-     def  generate_type_definition (param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs ) ->  str :
655-         indent  =  '  '  *  indent_level 
656-         if  '$ref'  in  param :
659+     def  generate_type_definition (
660+         param : Dict [str , llama_types .JsonType ], indent_level : int , shared_defs 
661+     ) ->  str :
662+         indent  =  "  "  *  indent_level 
663+         if  "$ref"  in  param :
657664            # Reference to a shared definition 
658-             ref_name  =  param ['$ref' ].split ('/' )[- 1 ]  # Extract the type name from the reference 
665+             ref_name  =  param ["$ref" ].split ("/" )[
666+                 - 1 
667+             ]  # Extract the type name from the reference 
659668            return  ref_name 
660-         elif  param .get (' type' ==  ' array' 
661-             items  =  param .get (' items' 
669+         elif  param .get (" type" ==  " array" 
670+             items  =  param .get (" items" 
662671            item_type  =  generate_type_definition (items , indent_level  +  1 , shared_defs )
663672            return  f"Array<{ item_type }  
664-         elif  param .get (' type' ==  ' object' 
665-             properties  =  param .get (' properties' 
673+         elif  param .get (" type" ==  " object" 
674+             properties  =  param .get (" properties" 
666675            nested_schema  =  "{\n " 
667676            for  nested_param_name , nested_param  in  properties .items ():
668-                 nested_param_type  =  generate_type_definition (nested_param , indent_level  +  1 , shared_defs )
669-                 nested_schema  +=  f"{ indent } { nested_param_name } { nested_param_type } \n " 
677+                 nested_param_type  =  generate_type_definition (
678+                     nested_param , indent_level  +  1 , shared_defs 
679+                 )
680+                 nested_schema  +=  (
681+                     f"{ indent } { nested_param_name } { nested_param_type } \n " 
682+                 )
670683            nested_schema  +=  indent  +  "}" 
671684            return  nested_schema 
672-         elif  ' enum' in  param :
685+         elif  " enum" in  param :
673686            # Enum type 
674-             return  " | " .join ([f'"{ enum_value }   for  enum_value  in  param [' enum' 
687+             return  " | " .join ([f'"{ enum_value }   for  enum_value  in  param [" enum" 
675688        else :
676689            # Simple type 
677-             return  param .get (' type' ,  ' any' 
690+             return  param .get (" type" ,  " any" 
678691
679692    def  generate_shared_definitions (shared_defs , indent_level : int ) ->  str :
680-         indent  =  '  '  *  indent_level 
693+         indent  =  "  "  *  indent_level 
681694        shared_definitions  =  "" 
682695        for  def_name , def_properties  in  shared_defs .items ():
683696            shared_definitions  +=  f"{ indent } { def_name }  
684-             if  def_properties .get ('type' ) ==  'object' :
685-                 shared_definitions  +=  generate_type_definition (def_properties , indent_level , shared_defs )
686-             elif  'enum'  in  def_properties :
697+             if  def_properties .get ("type" ) ==  "object" :
698+                 shared_definitions  +=  generate_type_definition (
699+                     def_properties , indent_level , shared_defs 
700+                 )
701+             elif  "enum"  in  def_properties :
687702                # Enum type 
688-                 shared_definitions  +=  " | " .join ([f'"{ enum_value }   for  enum_value  in  def_properties ['enum' ]])
703+                 shared_definitions  +=  " | " .join (
704+                     [f'"{ enum_value }   for  enum_value  in  def_properties ["enum" ]]
705+                 )
689706            shared_definitions  +=  ";\n " 
690707        return  shared_definitions 
691708
692709    def  generate_schema_from_functions (functions , namespace = "functions" ) ->  str :
693-         schema  =  "// Supported function definitions that should be called when necessary.\n " 
710+         schema  =  (
711+             "// Supported function definitions that should be called when necessary.\n " 
712+         )
694713        schema  +=  f"namespace { namespace } \n \n " 
695714
696715        # Generate shared definitions 
@@ -706,10 +725,10 @@ def generate_schema_from_functions(functions, namespace="functions") -> str:
706725            description  =  function .get ("description" , "" )
707726            parameters  =  function .get ("parameters" , {})
708727            required_params  =  parameters .get ("required" , [])
709-              
728+ 
710729            schema  +=  f"  // { description } \n " 
711730            schema  +=  f"  type { function_name } \n " 
712-              
731+ 
713732            for  param_name , param  in  parameters .get ("properties" , {}).items ():
714733                param_description  =  param .get ("description" , "" )
715734                param_type  =  generate_type_definition (param , 2 , shared_definitions )
@@ -733,13 +752,18 @@ def prepare_messages_for_inference(
733752                    role = "system" , content = generate_schema_from_functions (functions )
734753                )
735754            )
736-          
755+ 
737756        if  tools  is  not None :
738757            all_messages .append (
739758                llama_types .ChatCompletionRequestSystemMessage (
740-                     role = "system" , content = generate_schema_from_functions (
741-                         [tool ["function" ] for  tool  in  tools  if  tool ["type" ] ==  "function" ]
742-                     )
759+                     role = "system" ,
760+                     content = generate_schema_from_functions (
761+                         [
762+                             tool ["function" ]
763+                             for  tool  in  tools 
764+                             if  tool ["type" ] ==  "function" 
765+                         ]
766+                     ),
743767                )
744768            )
745769
@@ -790,7 +814,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
790814                elif  "function_call"  in  msg :
791815                    return  f"assistant to={ msg ['function_call' ]['name' ]} \n { msg ['function_call' ]['arguments' ]} \n " 
792816                elif  "tool_calls"  in  msg  and  len (msg ["tool_calls" ]) >  0 :
793-                     for  tool_call  in  msg ["tool_calls" ]: # NOTE: probably doesn't work with the functionary model 
817+                     for  tool_call  in  msg [
818+                         "tool_calls" 
819+                     ]:  # NOTE: probably doesn't work with the functionary model 
794820                        return  f"assistant to={ tool_call ['id' ]} \n { tool_call ['function' ]['arguments' ]} \n " 
795821                elif  msg ["content" ] is  None :
796822                    return  "assistant" 
@@ -800,12 +826,14 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
800826                raise  ValueError (f"Unsupported role: { msg ['role' ]}  )
801827
802828        return  "" .join ([message_to_str (msg ) for  msg  in  all_messages ])
803-      
829+ 
804830    if  tools  is  not None :
805831        functions  =  [tool ["function" ] for  tool  in  tools  if  tool ["type" ] ==  "function" ]
806-      
832+ 
807833    if  tool_choice  is  not None :
808-         function_call  =  tool_choice  if  isinstance (tool_choice , str ) else  tool_choice ["function" ]
834+         function_call  =  (
835+             tool_choice  if  isinstance (tool_choice , str ) else  tool_choice ["function" ]
836+         )
809837
810838    prompt  =  prepare_messages_for_inference (messages , functions , tools )
811839
@@ -861,19 +889,27 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
861889        if  tool ["type" ] ==  "function"  and  tool ["function" ]["name" ] ==  function_call :
862890            function_body  =  tool ["function" ]["parameters" ]
863891            break 
864-      
892+ 
865893    if  function_body  is  not None :
866894        try :
867895            with  suppress_stdout_stderr (disable = llama .verbose ):
868-                 grammar_text  =  llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
869-                 grammar  =  llama_grammar .LlamaGrammar .from_string (llama_grammar .json_schema_to_gbnf (json .dumps (function_body )))
896+                 grammar_text  =  llama_grammar .json_schema_to_gbnf (
897+                     json .dumps (function_body )
898+                 )
899+                 grammar  =  llama_grammar .LlamaGrammar .from_string (
900+                     llama_grammar .json_schema_to_gbnf (json .dumps (function_body ))
901+                 )
870902                print (grammar_text )
871903        except  Exception  as  e :
872904            if  llama .verbose :
873-                 print ("Failed to parse function body as JSON schema, falling back to default grammar" )
905+                 print (
906+                     "Failed to parse function body as JSON schema, falling back to default grammar" 
907+                 )
874908                print (e )
875909            with  suppress_stdout_stderr (disable = llama .verbose ):
876-                 grammar  =  llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
910+                 grammar  =  llama_grammar .LlamaGrammar .from_string (
911+                     llama_grammar .JSON_GBNF 
912+                 )
877913    else :
878914        with  suppress_stdout_stderr (disable = llama .verbose ):
879915            grammar  =  llama_grammar .LlamaGrammar .from_string (llama_grammar .JSON_GBNF )
@@ -929,9 +965,9 @@ def message_to_str(msg: llama_types.ChatCompletionRequestMessage):
929965                            "function" : {
930966                                "name" : function_call ,
931967                                "arguments" : completion ["choices" ][0 ]["text" ],
932-                             }
968+                             }, 
933969                        }
934-                     ]
970+                     ], 
935971                },
936972                "finish_reason" : "tool_calls" ,
937973            }
0 commit comments