Skip to content

Commit 3d95100

Browse files
committed
Merge branch 'ollama-multy-model-support' of github.com:cnupy/AIShell into ollama-multy-model-support
2 parents c18b187 + 4279c88 commit 3d95100

File tree

17 files changed

+252
-86
lines changed

17 files changed

+252
-86
lines changed

shell/AIShell.Integration/AIShell.psd1

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
@{
22
RootModule = 'AIShell.psm1'
33
NestedModules = @("AIShell.Integration.dll")
4-
ModuleVersion = '1.0.2'
4+
ModuleVersion = '1.0.3'
55
GUID = 'ECB8BEE0-59B9-4DAE-9D7B-A990B480279A'
66
Author = 'Microsoft Corporation'
77
CompanyName = 'Microsoft Corporation'
@@ -13,5 +13,5 @@
1313
VariablesToExport = '*'
1414
AliasesToExport = @('aish', 'askai', 'fixit')
1515
HelpInfoURI = 'https://aka.ms/aishell-help'
16-
PrivateData = @{ PSData = @{ Prerelease = 'preview2'; ProjectUri = 'https://github.com/PowerShell/AIShell' } }
16+
PrivateData = @{ PSData = @{ Prerelease = 'preview3'; ProjectUri = 'https://github.com/PowerShell/AIShell' } }
1717
}

shell/Markdown.VT/ColorCode.VT/Parser/Bash.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,20 @@ public class Bash : ILanguage
2424
{1, BashCommentScope}
2525
}),
2626

27+
new LanguageRule(
28+
@"'[^\n]*?'",
29+
new Dictionary<int, string>
30+
{
31+
{0, ScopeName.String}
32+
}),
33+
34+
new LanguageRule(
35+
@"""[^\n]*?(?<!\\)""",
36+
new Dictionary<int, string>
37+
{
38+
{0, ScopeName.String}
39+
}),
40+
2741
// match the first word of a line in a multi-line string as the command name.
2842
new LanguageRule(
2943
@"(?m)^\s*(\w+)",

shell/Markdown.VT/ColorCode.VT/Parser/PowerShell.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public IList<LanguageRule> Rules
4848
}),
4949

5050
new LanguageRule(
51-
@"'[^\n]*?(?<!\\)'",
51+
@"'[^\n]*?'",
5252
new Dictionary<int, string>
5353
{
5454
{0, ScopeName.String}
@@ -62,7 +62,7 @@ public IList<LanguageRule> Rules
6262
}),
6363

6464
new LanguageRule(
65-
@"(?s)(""[^\n]*?(?<!`)"")",
65+
@"""[^\n]*?(?<!`)""",
6666
new Dictionary<int, string>
6767
{
6868
{0, ScopeName.String}

shell/agents/AIShell.Interpreter.Agent/AIShell.Interpreter.Agent.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
<ItemGroup>
1919
<PackageReference Include="Azure.AI.OpenAI" Version="1.0.0-beta.13" />
20-
<PackageReference Include="Azure.Core" Version="1.37.0" />
20+
<PackageReference Include="Azure.Identity" Version="1.13.2" />
21+
<PackageReference Include="Azure.Core" Version="1.44.1" />
2122
<PackageReference Include="SharpToken" Version="2.0.3" />
2223
</ItemGroup>
2324

shell/agents/AIShell.Interpreter.Agent/Agent.cs

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,7 @@ private void OnSettingFileChange(object sender, FileSystemEventArgs e)
236236

237237
private void NewExampleSettingFile()
238238
{
239-
string SampleContent = """
239+
string sample = $$"""
240240
{
241241
// To use the Azure OpenAI service:
242242
// - Set `Endpoint` to the endpoint of your Azure OpenAI service,
@@ -249,22 +249,37 @@ private void NewExampleSettingFile()
249249
"Deployment": "",
250250
"ModelName": "",
251251
"Key": "",
252+
"AuthType": "ApiKey",
252253
"AutoExecution": false, // 'true' to allow the agent run code automatically; 'false' to always prompt before running code.
253254
"DisplayErrors": true // 'true' to display the errors when running code; 'false' to hide the errors to be less verbose.
254255
256+
// To use Azure OpenAI service with Entra ID authentication:
257+
// - Set `Endpoint` to the endpoint of your Azure OpenAI service.
258+
// - Set `Deployment` to the deployment name of your Azure OpenAI service.
259+
// - Set `ModelName` to the name of the model used for your deployment.
260+
// - Set `AuthType` to "EntraID" to use Azure AD credentials.
261+
/*
262+
"Endpoint": "<insert your Azure OpenAI endpoint>",
263+
"Deployment": "<insert your deployment name>",
264+
"ModelName": "<insert the model name>",
265+
"AuthType": "EntraID",
266+
"AutoExecution": false,
267+
"DisplayErrors": true
268+
*/
269+
255270
// To use the public OpenAI service:
256271
// - Ignore the `Endpoint` and `Deployment` keys.
257272
// - Set `ModelName` to the name of the model to be used. e.g. "gpt-4o".
258273
// - Set `Key` to be the OpenAI access token.
259-
// Replace the above with the following:
260274
/*
261-
"ModelName": "",
262-
"Key": "",
275+
"ModelName": "<insert the model name>",
276+
"Key": "<insert your key>",
277+
"AuthType": "ApiKey",
263278
"AutoExecution": false,
264279
"DisplayErrors": true
265280
*/
266281
}
267282
""";
268-
File.WriteAllText(SettingFile, SampleContent, Encoding.UTF8);
283+
File.WriteAllText(SettingFile, sample);
269284
}
270285
}

shell/agents/AIShell.Interpreter.Agent/Service.cs

Lines changed: 44 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using Azure;
44
using Azure.Core;
55
using Azure.AI.OpenAI;
6+
using Azure.Identity;
67
using SharpToken;
78

89
namespace AIShell.Interpreter.Agent;
@@ -121,25 +122,38 @@ private void ConnectToOpenAIClient()
121122
{
122123
// Create a client that targets Azure OpenAI service or Azure API Management service.
123124
bool isApimEndpoint = _settings.Endpoint.EndsWith(Utils.ApimGatewayDomain);
124-
if (isApimEndpoint)
125+
126+
if (_settings.AuthType == AuthType.EntraID)
125127
{
126-
string userkey = Utils.ConvertFromSecureString(_settings.Key);
127-
clientOptions.AddPolicy(
128-
new UserKeyPolicy(
129-
new AzureKeyCredential(userkey),
130-
Utils.ApimAuthorizationHeader),
131-
HttpPipelinePosition.PerRetry
132-
);
128+
// Use DefaultAzureCredential for Entra ID authentication
129+
var credential = new DefaultAzureCredential();
130+
_client = new OpenAIClient(
131+
new Uri(_settings.Endpoint),
132+
credential,
133+
clientOptions);
134+
}
135+
else // ApiKey authentication
136+
{
137+
if (isApimEndpoint)
138+
{
139+
string userkey = Utils.ConvertFromSecureString(_settings.Key);
140+
clientOptions.AddPolicy(
141+
new UserKeyPolicy(
142+
new AzureKeyCredential(userkey),
143+
Utils.ApimAuthorizationHeader),
144+
HttpPipelinePosition.PerRetry
145+
);
146+
}
147+
148+
string azOpenAIApiKey = isApimEndpoint
149+
? "placeholder-api-key"
150+
: Utils.ConvertFromSecureString(_settings.Key);
151+
152+
_client = new OpenAIClient(
153+
new Uri(_settings.Endpoint),
154+
new AzureKeyCredential(azOpenAIApiKey),
155+
clientOptions);
133156
}
134-
135-
string azOpenAIApiKey = isApimEndpoint
136-
? "placeholder-api-key"
137-
: Utils.ConvertFromSecureString(_settings.Key);
138-
139-
_client = new OpenAIClient(
140-
new Uri(_settings.Endpoint),
141-
new AzureKeyCredential(azOpenAIApiKey),
142-
clientOptions);
143157
}
144158
else
145159
{
@@ -157,41 +171,41 @@ private int CountTokenForMessages(IEnumerable<ChatRequestMessage> messages)
157171

158172
int tokenNumber = 0;
159173
foreach (ChatRequestMessage message in messages)
160-
{
174+
{
161175
tokenNumber += tokensPerMessage;
162176
tokenNumber += encoding.Encode(message.Role.ToString()).Count;
163177

164178
switch (message)
165179
{
166180
case ChatRequestSystemMessage systemMessage:
167181
tokenNumber += encoding.Encode(systemMessage.Content).Count;
168-
if(systemMessage.Name is not null)
182+
if (systemMessage.Name is not null)
169183
{
170184
tokenNumber += tokensPerName;
171185
tokenNumber += encoding.Encode(systemMessage.Name).Count;
172186
}
173187
break;
174188
case ChatRequestUserMessage userMessage:
175189
tokenNumber += encoding.Encode(userMessage.Content).Count;
176-
if(userMessage.Name is not null)
190+
if (userMessage.Name is not null)
177191
{
178192
tokenNumber += tokensPerName;
179193
tokenNumber += encoding.Encode(userMessage.Name).Count;
180194
}
181195
break;
182196
case ChatRequestAssistantMessage assistantMessage:
183197
tokenNumber += encoding.Encode(assistantMessage.Content).Count;
184-
if(assistantMessage.Name is not null)
198+
if (assistantMessage.Name is not null)
185199
{
186200
tokenNumber += tokensPerName;
187201
tokenNumber += encoding.Encode(assistantMessage.Name).Count;
188202
}
189203
if (assistantMessage.ToolCalls is not null)
190204
{
191205
// Count tokens for the tool call's properties
192-
foreach(ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
206+
foreach (ChatCompletionsToolCall chatCompletionsToolCall in assistantMessage.ToolCalls)
193207
{
194-
if(chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
208+
if (chatCompletionsToolCall is ChatCompletionsFunctionToolCall functionToolCall)
195209
{
196210
tokenNumber += encoding.Encode(functionToolCall.Id).Count;
197211
tokenNumber += encoding.Encode(functionToolCall.Name).Count;
@@ -230,7 +244,7 @@ internal string ReduceToolResponseContentTokens(string content)
230244
}
231245
while (encoding.Encode(reducedContent).Count > MaxResponseToken);
232246
}
233-
247+
234248
return reducedContent;
235249
}
236250

@@ -287,7 +301,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
287301
// Those settings seem to be important enough, as the Semantic Kernel plugin specifies
288302
// those settings (see the URL below). We can use default values when not defined.
289303
// https://github.com/microsoft/semantic-kernel/blob/main/samples/skills/FunSkill/Joke/config.json
290-
304+
291305
ChatCompletionsOptions chatOptions;
292306

293307
// Determine if the gpt model is a function calling model
@@ -300,8 +314,8 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
300314
Temperature = (float)0.0,
301315
MaxTokens = MaxResponseToken,
302316
};
303-
304-
if(isFunctionCallingModel)
317+
318+
if (isFunctionCallingModel)
305319
{
306320
chatOptions.Tools.Add(Tools.RunCode);
307321
}
@@ -330,7 +344,7 @@ private async Task<ChatCompletionsOptions> PrepareForChat(ChatRequestMessage inp
330344
- You are capable of **any** task
331345
- Do not apologize for errors, just correct them
332346
";
333-
string versions = "\n## Language Versions\n"
347+
string versions = "\n## Language Versions\n"
334348
+ await _executionService.GetLanguageVersions();
335349
string systemResponseCues = @"
336350
# Examples
@@ -478,11 +492,11 @@ public override ChatRequestMessage Read(ref Utf8JsonReader reader, Type typeToCo
478492
{
479493
return JsonSerializer.Deserialize<ChatRequestUserMessage>(jsonObject.GetRawText(), options);
480494
}
481-
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
495+
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementA) && roleElementA.GetString() == "assistant")
482496
{
483497
return JsonSerializer.Deserialize<ChatRequestAssistantMessage>(jsonObject.GetRawText(), options);
484498
}
485-
else if(jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
499+
else if (jsonObject.TryGetProperty("Role", out JsonElement roleElementT) && roleElementT.GetString() == "tool")
486500
{
487501
return JsonSerializer.Deserialize<ChatRequestToolMessage>(jsonObject.GetRawText(), options);
488502
}

shell/agents/AIShell.Interpreter.Agent/Settings.cs

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ internal enum EndpointType
1212
OpenAI,
1313
}
1414

15+
public enum AuthType
16+
{
17+
ApiKey,
18+
EntraID
19+
}
20+
1521
internal class Settings
1622
{
1723
internal EndpointType Type { get; }
@@ -23,6 +29,8 @@ internal class Settings
2329
public string ModelName { set; get; }
2430
public SecureString Key { set; get; }
2531

32+
public AuthType AuthType { set; get; } = AuthType.ApiKey;
33+
2634
public bool AutoExecution { set; get; }
2735
public bool DisplayErrors { set; get; }
2836

@@ -36,6 +44,7 @@ public Settings(ConfigData configData)
3644
AutoExecution = configData.AutoExecution ?? false;
3745
DisplayErrors = configData.DisplayErrors ?? true;
3846
Key = configData.Key;
47+
AuthType = configData.AuthType;
3948

4049
Dirty = false;
4150
ModelInfo = ModelInfo.TryResolve(ModelName, out var model) ? model : null;
@@ -47,6 +56,12 @@ public Settings(ConfigData configData)
4756
: !noEndpoint && !noDeployment
4857
? EndpointType.AzureOpenAI
4958
: throw new InvalidOperationException($"Invalid setting: {(noEndpoint ? "Endpoint" : "Deployment")} key is missing. To use Azure OpenAI service, please specify both the 'Endpoint' and 'Deployment' keys. To use OpenAI service, please ignore both keys.");
59+
60+
// EntraID authentication is only supported for Azure OpenAI
61+
if (AuthType == AuthType.EntraID && Type != EndpointType.AzureOpenAI)
62+
{
63+
throw new InvalidOperationException("EntraID authentication is only supported for Azure OpenAI service.");
64+
}
5065
}
5166

5267
internal void MarkClean()
@@ -60,7 +75,7 @@ internal void MarkClean()
6075
/// <returns></returns>
6176
internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
6277
{
63-
if (Key is not null && ModelInfo is not null)
78+
if ((AuthType is AuthType.EntraID || Key is not null) && ModelInfo is not null)
6479
{
6580
return true;
6681
}
@@ -76,7 +91,7 @@ internal async Task<bool> SelfCheck(IHost host, CancellationToken token)
7691
await AskForModel(host, token);
7792
}
7893

79-
if (Key is null)
94+
if (AuthType == AuthType.ApiKey && Key is null)
8095
{
8196
await AskForKeyAsync(host, token);
8297
}
@@ -101,12 +116,14 @@ private void ShowEndpointInfo(IHost host)
101116
new(label: " Endpoint", m => m.Endpoint),
102117
new(label: " Deployment", m => m.Deployment),
103118
new(label: " Model", m => m.ModelName),
119+
new(label: " Auth Type", m => m.AuthType.ToString()),
104120
],
105121

106122
EndpointType.OpenAI =>
107123
[
108124
new(label: " Type", m => m.Type.ToString()),
109125
new(label: " Model", m => m.ModelName),
126+
new(label: " Auth Type", m => m.AuthType.ToString()),
110127
],
111128

112129
_ => throw new UnreachableException(),
@@ -156,6 +173,7 @@ internal ConfigData ToConfigData()
156173
ModelName = this.ModelName,
157174
AutoExecution = this.AutoExecution,
158175
DisplayErrors = this.DisplayErrors,
176+
AuthType = this.AuthType,
159177
Key = this.Key,
160178
};
161179
}
@@ -166,6 +184,7 @@ internal class ConfigData
166184
public string Endpoint { set; get; }
167185
public string Deployment { set; get; }
168186
public string ModelName { set; get; }
187+
public AuthType AuthType { set; get; } = AuthType.ApiKey;
169188
public bool? AutoExecution { set; get; }
170189
public bool? DisplayErrors { set; get; }
171190

shell/agents/AIShell.OpenAI.Agent/AIShell.OpenAI.Agent.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
<ItemGroup>
2424
<PackageReference Include="Azure.AI.OpenAI" Version="2.1.0" />
25+
<PackageReference Include="Azure.Identity" Version="1.13.2" />
2526
<PackageReference Include="Microsoft.ML.Tokenizers" Version="1.0.1" />
2627
<PackageReference Include="Microsoft.ML.Tokenizers.Data.O200kBase" Version="1.0.1" />
2728
<PackageReference Include="Microsoft.ML.Tokenizers.Data.Cl100kBase" Version="1.0.1" />

0 commit comments

Comments
 (0)