33using Azure ;
44using Azure . Core ;
55using Azure . AI . OpenAI ;
6+ using Azure . Identity ;
67using SharpToken ;
78
89namespace AIShell . Interpreter . Agent ;
@@ -121,25 +122,38 @@ private void ConnectToOpenAIClient()
121122 {
122123 // Create a client that targets Azure OpenAI service or Azure API Management service.
123124 bool isApimEndpoint = _settings . Endpoint . EndsWith ( Utils . ApimGatewayDomain ) ;
124- if ( isApimEndpoint )
125+
126+ if ( _settings . AuthType == AuthType . EntraID )
125127 {
126- string userkey = Utils . ConvertFromSecureString ( _settings . Key ) ;
127- clientOptions . AddPolicy (
128- new UserKeyPolicy (
129- new AzureKeyCredential ( userkey ) ,
130- Utils . ApimAuthorizationHeader ) ,
131- HttpPipelinePosition . PerRetry
132- ) ;
128+ // Use DefaultAzureCredential for Entra ID authentication
129+ var credential = new DefaultAzureCredential ( ) ;
130+ _client = new OpenAIClient (
131+ new Uri ( _settings . Endpoint ) ,
132+ credential ,
133+ clientOptions ) ;
134+ }
135+ else // ApiKey authentication
136+ {
137+ if ( isApimEndpoint )
138+ {
139+ string userkey = Utils . ConvertFromSecureString ( _settings . Key ) ;
140+ clientOptions . AddPolicy (
141+ new UserKeyPolicy (
142+ new AzureKeyCredential ( userkey ) ,
143+ Utils . ApimAuthorizationHeader ) ,
144+ HttpPipelinePosition . PerRetry
145+ ) ;
146+ }
147+
148+ string azOpenAIApiKey = isApimEndpoint
149+ ? "placeholder-api-key"
150+ : Utils . ConvertFromSecureString ( _settings . Key ) ;
151+
152+ _client = new OpenAIClient (
153+ new Uri ( _settings . Endpoint ) ,
154+ new AzureKeyCredential ( azOpenAIApiKey ) ,
155+ clientOptions ) ;
133156 }
134-
135- string azOpenAIApiKey = isApimEndpoint
136- ? "placeholder-api-key"
137- : Utils . ConvertFromSecureString ( _settings . Key ) ;
138-
139- _client = new OpenAIClient (
140- new Uri ( _settings . Endpoint ) ,
141- new AzureKeyCredential ( azOpenAIApiKey ) ,
142- clientOptions ) ;
143157 }
144158 else
145159 {
@@ -157,41 +171,41 @@ private int CountTokenForMessages(IEnumerable<ChatRequestMessage> messages)
157171
158172 int tokenNumber = 0 ;
159173 foreach ( ChatRequestMessage message in messages )
160- {
174+ {
161175 tokenNumber += tokensPerMessage ;
162176 tokenNumber += encoding . Encode ( message . Role . ToString ( ) ) . Count ;
163177
164178 switch ( message )
165179 {
166180 case ChatRequestSystemMessage systemMessage :
167181 tokenNumber += encoding . Encode ( systemMessage . Content ) . Count ;
168- if ( systemMessage . Name is not null )
182+ if ( systemMessage . Name is not null )
169183 {
170184 tokenNumber += tokensPerName ;
171185 tokenNumber += encoding . Encode ( systemMessage . Name ) . Count ;
172186 }
173187 break ;
174188 case ChatRequestUserMessage userMessage :
175189 tokenNumber += encoding . Encode ( userMessage . Content ) . Count ;
176- if ( userMessage . Name is not null )
190+ if ( userMessage . Name is not null )
177191 {
178192 tokenNumber += tokensPerName ;
179193 tokenNumber += encoding . Encode ( userMessage . Name ) . Count ;
180194 }
181195 break ;
182196 case ChatRequestAssistantMessage assistantMessage :
183197 tokenNumber += encoding . Encode ( assistantMessage . Content ) . Count ;
184- if ( assistantMessage . Name is not null )
198+ if ( assistantMessage . Name is not null )
185199 {
186200 tokenNumber += tokensPerName ;
187201 tokenNumber += encoding . Encode ( assistantMessage . Name ) . Count ;
188202 }
189203 if ( assistantMessage . ToolCalls is not null )
190204 {
191205 // Count tokens for the tool call's properties
192- foreach ( ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage . ToolCalls )
206+ foreach ( ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage . ToolCalls )
193207 {
194- if ( chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall )
208+ if ( chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall )
195209 {
196210 tokenNumber += encoding . Encode ( functionToolCall . Id ) . Count ;
197211 tokenNumber += encoding . Encode ( functionToolCall . Name ) . Count ;
@@ -230,7 +244,7 @@ internal string ReduceToolResponseContentTokens(string content)
230244 }
231245 while ( encoding . Encode ( reducedContent ) . Count > MaxResponseToken ) ;
232246 }
233-
247+
234248 return reducedContent ;
235249 }
236250
@@ -287,7 +301,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
287301 // Those settings seem to be important enough, as the Semantic Kernel plugin specifies
288302 // those settings (see the URL below). We can use default values when not defined.
289303 // https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json
290-
304+
291305 ChatCompletionsOptions chatOptions ;
292306
293307 // Determine if the gpt model is a function calling model
@@ -300,8 +314,8 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
300314 Temperature = ( float ) 0.0 ,
301315 MaxTokens = MaxResponseToken ,
302316 } ;
303-
304- if ( isFunctionCallingModel )
317+
318+ if ( isFunctionCallingModel )
305319 {
306320 chatOptions . Tools . Add ( Tools . RunCode ) ;
307321 }
@@ -330,7 +344,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
330344- You are capable of **any** task
331345- Do not apologize for errors, just correct them
332346" ;
333- string versions = "\n ## Language Versions\n "
347+ string versions = "\n ## Language Versions\n "
334348 + await _executionService . GetLanguageVersions ( ) ;
335349 string systemResponseCues = @"
336350# Examples
@@ -478,11 +492,11 @@ public override ChatRequestMessage Read(ref Utf8JsonReader reader, Type typeToCo
478492 {
479493 return JsonSerializer . Deserialize < ChatRequestUserMessage > ( jsonObject . GetRawText ( ) , options ) ;
480494 }
481- else if ( jsonObject . TryGetProperty ( "Role" , out JsonElement roleElementA ) && roleElementA . GetString ( ) == "assistant" )
495+ else if ( jsonObject . TryGetProperty ( "Role" , out JsonElement roleElementA ) && roleElementA . GetString ( ) == "assistant" )
482496 {
483497 return JsonSerializer . Deserialize < ChatRequestAssistantMessage > ( jsonObject . GetRawText ( ) , options ) ;
484498 }
485- else if ( jsonObject . TryGetProperty ( "Role" , out JsonElement roleElementT ) && roleElementT . GetString ( ) == "tool" )
499+ else if ( jsonObject . TryGetProperty ( "Role" , out JsonElement roleElementT ) && roleElementT . GetString ( ) == "tool" )
486500 {
487501 return JsonSerializer . Deserialize < ChatRequestToolMessage > ( jsonObject . GetRawText ( ) , options ) ;
488502 }
0 commit comments