diff --git a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs index c0fd3a88..22ed2f34 100644 --- a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs +++ b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs @@ -35,13 +35,17 @@ 7. DO NOT include the placeholder summary when the commands contains no placehol private int _turnsLeft; private readonly string _instructions; private readonly StringBuilder _buffer; + private readonly HttpClient _httpClient; private readonly ChatSession _chatSession; private readonly Dictionary _valueStore; public AzureAgent() { _buffer = new StringBuilder(); - _chatSession = new ChatSession(); + _httpClient = new HttpClient(); + Task.Run(() => DataRetriever.WarmUpMetadataService(_httpClient)); + + _chatSession = new ChatSession(_httpClient); _valueStore = new Dictionary(StringComparer.OrdinalIgnoreCase); _instructions = string.Format(InstructionPrompt, Environment.OSVersion.VersionString); @@ -66,7 +70,9 @@ public AzureAgent() public void Dispose() { - _chatSession?.Dispose(); + ArgPlaceholder?.DataRetriever?.Dispose(); + _chatSession.Dispose(); + _httpClient.Dispose(); } public void Initialize(AgentConfig config) @@ -126,7 +132,7 @@ public async Task ChatAsync(string input, IShell shell) string answer = data is null ? copilotResponse.Text : GenerateAnswer(data); if (data?.PlaceholderSet is not null) { - ArgPlaceholder = new ArgumentPlaceholder(input, data); + ArgPlaceholder = new ArgumentPlaceholder(input, data, _httpClient); } host.RenderFullResponse(answer); diff --git a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs index 0a84b82e..cad2484d 100644 --- a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs +++ b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs @@ -23,10 +23,10 @@ internal class ChatSession : IDisposable private readonly HttpClient _httpClient; private readonly Dictionary _flights; - internal ChatSession() + internal ChatSession(HttpClient httpClient) { _dl_secret = Environment.GetEnvironmentVariable("DL_SECRET"); - _httpClient = new HttpClient(); + _httpClient = httpClient; // Keys and values for flights are from the portal request. _flights = new Dictionary() @@ -199,14 +199,14 @@ private HttpRequestMessage PrepareForChat(string input) text = input, attachments = new object[] { new { - contentType = "application/json", + contentType = Utils.JsonContentType, name = "azurecopilot/clienthandlerdefinitions", content = new { clientHandlers = Array.Empty() } }, new { - contentType = "application/json", + contentType = Utils.JsonContentType, name = "azurecopilot/viewcontext", content = new { viewContext = new { @@ -217,7 +217,7 @@ private HttpRequestMessage PrepareForChat(string input) } }, new { - contentType = "application/json", + contentType = Utils.JsonContentType, name = "azurecopilot/flights", content = new { flights = _flights @@ -227,7 +227,7 @@ private HttpRequestMessage PrepareForChat(string input) }; var json = JsonSerializer.Serialize(requestData, Utils.JsonOptions); - var content = new StringContent(json, Encoding.UTF8, "application/json"); + var content = new StringContent(json, Encoding.UTF8, Utils.JsonContentType); var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content }; request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); @@ -305,7 +305,6 @@ internal async Task GetChatResponseAsync(string input, IStatusC public void Dispose() { - _httpClient.Dispose(); _copilotReceiver?.Dispose(); } } diff --git a/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs b/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs index 2b8ec8c1..6ba09779 100644 --- a/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs +++ b/shell/agents/Microsoft.Azure.Agent/DataRetriever.cs @@ -1,6 +1,7 @@ using System.Collections.Concurrent; using System.ComponentModel; using System.Diagnostics; +using System.Text; using System.Text.Json; using System.Text.RegularExpressions; using AIShell.Abstraction; @@ -9,11 +10,14 @@ namespace Microsoft.Azure.Agent; internal class DataRetriever : IDisposable { + private const string MetadataQueryTemplate = "{{\"command\":\"{0}\"}}"; + private const string MetadataEndpoint = "https://cli-validation-tool-meta-qry.azurewebsites.net/api/command_metadata"; + private static readonly Dictionary s_azNamingRules; - private static readonly ConcurrentDictionary s_azStaticDataCache; + private static readonly ConcurrentDictionary s_azStaticDataCache; - private readonly string _staticDataRoot; private readonly Task _rootTask; + private readonly HttpClient _httpClient; private readonly SemaphoreSlim _semaphore; private readonly List _placeholders; private readonly Dictionary _placeholderMap; @@ -302,11 +306,11 @@ static DataRetriever() s_azStaticDataCache = new(StringComparer.OrdinalIgnoreCase); } - internal DataRetriever(ResponseData data) + internal DataRetriever(ResponseData data, HttpClient httpClient) { _stop = false; + _httpClient = httpClient; _semaphore = new SemaphoreSlim(3, 3); - _staticDataRoot = @"E:\yard\tmp\az-cli-out\az"; _placeholders = new(capacity: data.PlaceholderSet.Count); _placeholderMap = new(capacity: data.PlaceholderSet.Count); @@ -453,31 +457,23 @@ private ArgumentInfo CreateArgInfo(ArgumentPair pair) private List GetArgValues(ArgumentPair pair) { // First, try to get static argument values if they exist. + bool hasCompleter = true; string command = pair.Command; - if (!s_azStaticDataCache.TryGetValue(command, out Command commandData)) + + AzCLICommand commandData = s_azStaticDataCache.GetOrAdd(command, QueryForMetadata); + AzCLIParameter param = commandData?.FindParameter(pair.Parameter); + + if (param is not null) { - string[] cmdElements = command.Split(' ', StringSplitOptions.RemoveEmptyEntries); - string dirPath = _staticDataRoot; - for (int i = 1; i < cmdElements.Length - 1; i++) + if (param.Choices?.Count > 0) { - dirPath = Path.Combine(dirPath, cmdElements[i]); + return param.Choices; } - string filePath = Path.Combine(dirPath, cmdElements[^1] + ".json"); - commandData = File.Exists(filePath) - ? JsonSerializer.Deserialize(File.OpenRead(filePath)) - : null; - s_azStaticDataCache.TryAdd(command, commandData); - } - - Option option = commandData?.FindOption(pair.Parameter); - List staticValues = option?.Arguments; - if (staticValues?.Count > 0) - { - return staticValues; + hasCompleter = param.HasCompleter; } - if (_stop) { return null; } + if (_stop || !hasCompleter) { return null; } // Then, try to get dynamic argument values using AzCLI tab completion. string commandLine = $"{pair.Command} {pair.Parameter} "; @@ -551,6 +547,42 @@ private List GetArgValues(ArgumentPair pair) } } + private AzCLICommand QueryForMetadata(string azCommand) + { + AzCLICommand command = null; + var reqBody = new StringContent(string.Format(MetadataQueryTemplate, azCommand), Encoding.UTF8, Utils.JsonContentType); + var request = new HttpRequestMessage(HttpMethod.Get, MetadataEndpoint) { Content = reqBody }; + + try + { + using var cts = new CancellationTokenSource(1200); + var response = _httpClient.Send(request, HttpCompletionOption.ResponseHeadersRead, cts.Token); + + if (response.IsSuccessStatusCode) + { + using Stream stream = response.Content.ReadAsStream(cts.Token); + using JsonDocument document = JsonDocument.Parse(stream); + + JsonElement root = document.RootElement; + if (root.TryGetProperty("data", out JsonElement data) && + data.TryGetProperty("metadata", out JsonElement metadata)) + { + command = metadata.Deserialize(Utils.JsonOptions); + } + } + else + { + // TODO: telemetry. + } + } + catch (Exception) + { + // TODO: telemetry. + } + + return command; + } + internal (string command, string parameter) GetMappedCommand(string placeholderName) { if (_placeholderMap.TryGetValue(placeholderName, out ArgumentPair pair)) @@ -585,6 +617,22 @@ public void Dispose() _rootTask.Wait(); _semaphore.Dispose(); } + + internal static void WarmUpMetadataService(HttpClient httpClient) + { + // Send a request to the AzCLI metadata service to warm up the service (code start is slow). + // We query for the command 'az sql server list' which only has 2 parameters, + // so it should cause minimum processing on the server side. + HttpRequestMessage request = new(HttpMethod.Get, MetadataEndpoint) + { + Content = new StringContent( + "{\"command\":\"az sql server list\"}", + Encoding.UTF8, + Utils.JsonContentType) + }; + + _ = httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead); + } } internal class ArgumentPair @@ -703,83 +751,3 @@ internal bool TryMatchName(string name, out string prodName, out string envName) return false; } } - -public class Option -{ - public string Name { get; } - public string[] Alias { get; } - public string[] Short { get; } - public string Attribute { get; } - public string Description { get; set; } - public List Arguments { get; set; } - - public Option(string name, string description, string[] alias, string[] @short, string attribute, List arguments) - { - ArgumentException.ThrowIfNullOrEmpty(name); - ArgumentException.ThrowIfNullOrEmpty(description); - - Name = name; - Alias = alias; - Short = @short; - Attribute = attribute; - Description = description; - Arguments = arguments; - } -} - -public sealed class Command -{ - public List