11import logging
22import time
3- from typing import Any , List , Optional
3+ from typing import Any , List , Optional , Union
44
55from pydantic import Field
66from transformers import AutoTokenizer , GPT2TokenizerFast
1212
class HFBaseChatModel(AbstractChatModel):
    """
    Custom LLM chatbot that can interface with HuggingFace models, with
    support for generating multiple samples per call.

    This class allows for the creation of a custom chatbot using models hosted
    on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
    the prompt construction used for the model's input sequence.

    Attributes:
        llm (Any): The HuggingFaceHub model instance.
        prompt_template (Any): Template for the prompt to be used for the model's input sequence.
        tokenizer (Any): The tokenizer to use for the model.
        n_retry_server (int): Number of times to retry on server failure.
    """

    llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,44 +55,56 @@ def __init__(self, model_name, n_retry_server):
5355 def __call__ (
5456 self ,
5557 messages : list [dict ],
56- ) -> dict :
57-
58- # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
59-
58+ n_samples : int = 1 ,
59+ ) -> Union [AIMessage , List [AIMessage ]]:
60+ """
61+ Generate one or more responses for the given messages.
62+
63+ Args:
64+ messages: List of message dictionaries containing the conversation history.
65+ n_samples: Number of independent responses to generate. Defaults to 1.
66+
67+ Returns:
68+ If n_samples=1, returns a single AIMessage.
69+ If n_samples>1, returns a list of AIMessages.
70+ """
6071 if self .tokenizer :
61- # messages_formated = _convert_messages_to_dict(messages) ## ?
6272 try :
6373 if isinstance (messages , Discussion ):
6474 messages .merge ()
6575 prompt = self .tokenizer .apply_chat_template (messages , tokenize = False )
6676 except Exception as e :
6777 if "Conversation roles must alternate" in str (e ):
6878 logging .warning (
69- f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
79+ f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
7080 "Retrying with the 'system' role appended to the 'user' role."
7181 )
7282 messages = _prepend_system_to_first_user (messages )
7383 prompt = self .tokenizer .apply_chat_template (messages , tokenize = False )
7484 else :
7585 raise e
76-
7786 elif self .prompt_template :
7887 prompt = self .prompt_template .construct_prompt (messages )
7988
80- itr = 0
81- while True :
82- try :
83- response = AIMessage (self .llm (prompt ))
84- return response
85- except Exception as e :
86- if itr == self .n_retry_server - 1 :
87- raise e
88- logging .warning (
89- f"Failed to get a response from the server: \n { e } \n "
90- f"Retrying... ({ itr + 1 } /{ self .n_retry_server } )"
91- )
92- time .sleep (5 )
93- itr += 1
89+ responses = []
90+ for _ in range (n_samples ):
91+ itr = 0
92+ while True :
93+ try :
94+ response = AIMessage (self .llm (prompt ))
95+ responses .append (response )
96+ break
97+ except Exception as e :
98+ if itr == self .n_retry_server - 1 :
99+ raise e
100+ logging .warning (
101+ f"Failed to get a response from the server: \n { e } \n "
102+ f"Retrying... ({ itr + 1 } /{ self .n_retry_server } )"
103+ )
104+ time .sleep (5 )
105+ itr += 1
106+
107+ return responses [0 ] if n_samples == 1 else responses
94108
95109 def _llm_type (self ):
96110 return "huggingface"
0 commit comments