@@ -36,6 +36,7 @@ def __init__(
3636 self ,
3737 model_name : str ,
3838 model_params : Optional [dict [str , Any ]] = None ,
39+ system_instruction : Optional [str ] = None ,
3940 ):
4041 """
4142 Base class for OpenAI LLM.
@@ -54,7 +55,7 @@ def __init__(
5455 "Please install it with `pip install openai`."
5556 )
5657 self .openai = openai
57- super ().__init__ (model_name , model_params )
58+ super ().__init__ (model_name , model_params , system_instruction )
5859
5960 def get_messages (
6061 self ,
@@ -64,6 +65,32 @@ def get_messages(
6465 {"role" : "system" , "content" : input },
6566 ]
6667
68+ def get_conversation_history (
69+ self ,
70+ input : str ,
71+ chat_history : list [str ],
72+ ) -> Iterable [ChatCompletionMessageParam ]:
73+ messages = [{"role" : "system" , "content" : self .system_instruction }]
74+ for i , message in enumerate (chat_history ):
75+ if i % 2 == 0 :
76+ messages .append ({"role" : "user" , "content" : message })
77+ else :
78+ messages .append ({"role" : "assistant" , "content" : message })
79+ messages .append ({"role" : "user" , "content" : input })
80+ return messages
81+
82+ def chat (self , input : str , chat_history : list [str ]) -> LLMResponse :
83+ try :
84+ response = self .client .chat .completions .create (
85+ messages = self .get_conversation_history (input , chat_history ),
86+ model = self .model_name ,
87+ ** self .model_params ,
88+ )
89+ content = response .choices [0 ].message .content or ""
90+ return LLMResponse (content = content )
91+ except self .openai .OpenAIError as e :
92+ raise LLMGenerationError (e )
93+
6794 def invoke (self , input : str ) -> LLMResponse :
6895 """Sends a text input to the OpenAI chat completion model
6996 and returns the response's content.
@@ -118,6 +145,7 @@ def __init__(
118145 self ,
119146 model_name : str ,
120147 model_params : Optional [dict [str , Any ]] = None ,
148+ system_instruction : Optional [str ] = None ,
121149 ** kwargs : Any ,
122150 ):
123151 """OpenAI LLM
@@ -129,7 +157,7 @@ def __init__(
129157 model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
130158 kwargs: All other parameters will be passed to the openai.OpenAI init.
131159 """
132- super ().__init__ (model_name , model_params )
160+ super ().__init__ (model_name , model_params , system_instruction )
133161 self .client = self .openai .OpenAI (** kwargs )
134162 self .async_client = self .openai .AsyncOpenAI (** kwargs )
135163
@@ -139,6 +167,7 @@ def __init__(
139167 self ,
140168 model_name : str ,
141169 model_params : Optional [dict [str , Any ]] = None ,
170+ system_instruction : Optional [str ] = None ,
142171 ** kwargs : Any ,
143172 ):
144173 """Azure OpenAI LLM. Use this class when using an OpenAI model
@@ -149,6 +178,6 @@ def __init__(
149178 model_params (str): Parameters like temperature that will be passed to the model when text is sent to it
150179 kwargs: All other parameters will be passed to the openai.OpenAI init.
151180 """
152- super ().__init__ (model_name , model_params )
181+ super ().__init__ (model_name , model_params , system_instruction )
153182 self .client = self .openai .AzureOpenAI (** kwargs )
154183 self .async_client = self .openai .AsyncAzureOpenAI (** kwargs )
0 commit comments