11import logging
22import time
3- from typing import Any , List , Optional
3+ from typing import Any , List , Optional , Union
44
55from pydantic import Field
66from transformers import AutoTokenizer , GPT2TokenizerFast
1212
1313class HFBaseChatModel (AbstractChatModel ):
1414 """
15- Custom LLM Chatbot that can interface with HuggingFace models.
15+ Custom LLM Chatbot that can interface with HuggingFace models with support for multiple samples .
1616
1717 This class allows for the creation of a custom chatbot using models hosted
1818 on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
2222 Attributes:
2323 llm (Any): The HuggingFaceHub model instance.
2424 prompt_template (Any): Template for the prompt to be used for the model's input sequence.
25+ tokenizer (Any): The tokenizer to use for the model.
26+ n_retry_server (int): Number of times to retry on server failure.
2527 """
2628
2729 llm : Any = Field (description = "The HuggingFaceHub model instance" )
@@ -53,44 +55,62 @@ def __init__(self, model_name, n_retry_server):
def __call__(
    self,
    messages: list[dict],
    n_samples: int = 1,
    temperature: Optional[float] = None,
) -> Union[AIMessage, List[AIMessage]]:
    """
    Generate one or more responses for the given messages.

    Args:
        messages: List of message dictionaries containing the conversation history.
            A `Discussion` instance is also accepted; it is merged in place
            before the chat template is applied.
        n_samples: Number of independent responses to generate. Defaults to 1.
        temperature: The temperature for response sampling. When None, the
            instance-level `self.temperature` is used.

    Returns:
        If n_samples == 1, a single AIMessage.
        Otherwise, a list of AIMessages (empty when n_samples < 1).

    Raises:
        ValueError: If neither a tokenizer nor a prompt template is configured.
        Exception: If the server still fails on the last allowed attempt, or
            if applying the chat template fails for a reason other than
            non-alternating conversation roles.
    """
    prompt = self._build_prompt(messages)

    # Resolve the temperature once, outside the retry loop, so that a
    # misconfigured instance fails immediately instead of being retried
    # as if it were a server error.
    if temperature is None:
        temperature = self.temperature

    responses = [self._sample_once(prompt, temperature) for _ in range(n_samples)]
    return responses[0] if n_samples == 1 else responses

def _build_prompt(self, messages):
    """Turn `messages` into a single prompt string via the tokenizer or the prompt template."""
    if self.tokenizer:
        try:
            if isinstance(messages, Discussion):
                messages.merge()
            return self.tokenizer.apply_chat_template(messages, tokenize=False)
        except Exception as e:
            if "Conversation roles must alternate" not in str(e):
                raise
            logging.warning(
                "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                "Retrying with the 'system' role appended to the 'user' role."
            )
            messages = _prepend_system_to_first_user(messages)
            return self.tokenizer.apply_chat_template(messages, tokenize=False)

    if self.prompt_template:
        return self.prompt_template.construct_prompt(messages)

    # Previously this path fell through and crashed later with an
    # UnboundLocalError on `prompt`; fail with an explicit message instead.
    raise ValueError(
        "Cannot build a prompt: neither a tokenizer nor a prompt_template is configured."
    )

def _sample_once(self, prompt, temperature):
    """Query the model once, retrying on failure up to `n_retry_server` attempts in total."""
    # Always make at least one attempt; a non-positive `n_retry_server`
    # used to make the exit condition unreachable and loop forever.
    max_attempts = max(1, self.n_retry_server)
    for attempt in range(max_attempts):
        try:
            return AIMessage(self.llm(prompt, temperature=temperature))
        except Exception as e:
            if attempt == max_attempts - 1:
                raise
            logging.warning(
                f"Failed to get a response from the server: \n{e}\n"
                f"Retrying... ({attempt + 1}/{self.n_retry_server})"
            )
            time.sleep(5)
94114
95115 def _llm_type (self ):
96116 return "huggingface"
0 commit comments