Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions shell/agents/Microsoft.Azure.Agent/AzureAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,17 @@ 7. DO NOT include the placeholder summary when the commands contains no placehol
private int _turnsLeft;
private readonly string _instructions;
private readonly StringBuilder _buffer;
private readonly HttpClient _httpClient;
private readonly ChatSession _chatSession;
private readonly Dictionary<string, string> _valueStore;

public AzureAgent()
{
_buffer = new StringBuilder();
_chatSession = new ChatSession();
_httpClient = new HttpClient();
Task.Run(() => DataRetriever.WarmUpMetadataService(_httpClient));

_chatSession = new ChatSession(_httpClient);
_valueStore = new Dictionary<string, string>(StringComparer.OrdinalIgnoreCase);
_instructions = string.Format(InstructionPrompt, Environment.OSVersion.VersionString);

Expand All @@ -66,7 +70,9 @@ public AzureAgent()

public void Dispose()
{
_chatSession?.Dispose();
ArgPlaceholder?.DataRetriever?.Dispose();
_chatSession.Dispose();
_httpClient.Dispose();
}

public void Initialize(AgentConfig config)
Expand Down Expand Up @@ -126,7 +132,7 @@ public async Task<bool> ChatAsync(string input, IShell shell)
string answer = data is null ? copilotResponse.Text : GenerateAnswer(data);
if (data?.PlaceholderSet is not null)
{
ArgPlaceholder = new ArgumentPlaceholder(input, data);
ArgPlaceholder = new ArgumentPlaceholder(input, data, _httpClient);
}

host.RenderFullResponse(answer);
Expand Down
13 changes: 6 additions & 7 deletions shell/agents/Microsoft.Azure.Agent/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ internal class ChatSession : IDisposable
private readonly HttpClient _httpClient;
private readonly Dictionary<string, object> _flights;

internal ChatSession()
internal ChatSession(HttpClient httpClient)
{
_dl_secret = Environment.GetEnvironmentVariable("DL_SECRET");
_httpClient = new HttpClient();
_httpClient = httpClient;

// Keys and values for flights are from the portal request.
_flights = new Dictionary<string, object>()
Expand Down Expand Up @@ -199,14 +199,14 @@ private HttpRequestMessage PrepareForChat(string input)
text = input,
attachments = new object[] {
new {
contentType = "application/json",
contentType = Utils.JsonContentType,
name = "azurecopilot/clienthandlerdefinitions",
content = new {
clientHandlers = Array.Empty<object>()
}
},
new {
contentType = "application/json",
contentType = Utils.JsonContentType,
name = "azurecopilot/viewcontext",
content = new {
viewContext = new {
Expand All @@ -217,7 +217,7 @@ private HttpRequestMessage PrepareForChat(string input)
}
},
new {
contentType = "application/json",
contentType = Utils.JsonContentType,
name = "azurecopilot/flights",
content = new {
flights = _flights
Expand All @@ -227,7 +227,7 @@ private HttpRequestMessage PrepareForChat(string input)
};

var json = JsonSerializer.Serialize(requestData, Utils.JsonOptions);
var content = new StringContent(json, Encoding.UTF8, "application/json");
var content = new StringContent(json, Encoding.UTF8, Utils.JsonContentType);
var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content };

request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token);
Expand Down Expand Up @@ -305,7 +305,6 @@ internal async Task<CopilotResponse> GetChatResponseAsync(string input, IStatusC

public void Dispose()
{
_httpClient.Dispose();
_copilotReceiver?.Dispose();
}
}
172 changes: 70 additions & 102 deletions shell/agents/Microsoft.Azure.Agent/DataRetriever.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Collections.Concurrent;
using System.ComponentModel;
using System.Diagnostics;
using System.Text;
using System.Text.Json;
using System.Text.RegularExpressions;
using AIShell.Abstraction;
Expand All @@ -9,11 +10,14 @@ namespace Microsoft.Azure.Agent;

internal class DataRetriever : IDisposable
{
private const string MetadataQueryTemplate = "{{\"command\":\"{0}\"}}";
private const string MetadataEndpoint = "https://cli-validation-tool-meta-qry.azurewebsites.net/api/command_metadata";

private static readonly Dictionary<string, NamingRule> s_azNamingRules;
private static readonly ConcurrentDictionary<string, Command> s_azStaticDataCache;
private static readonly ConcurrentDictionary<string, AzCLICommand> s_azStaticDataCache;

private readonly string _staticDataRoot;
private readonly Task _rootTask;
private readonly HttpClient _httpClient;
private readonly SemaphoreSlim _semaphore;
private readonly List<ArgumentPair> _placeholders;
private readonly Dictionary<string, ArgumentPair> _placeholderMap;
Expand Down Expand Up @@ -302,11 +306,11 @@ static DataRetriever()
s_azStaticDataCache = new(StringComparer.OrdinalIgnoreCase);
}

internal DataRetriever(ResponseData data)
internal DataRetriever(ResponseData data, HttpClient httpClient)
{
_stop = false;
_httpClient = httpClient;
_semaphore = new SemaphoreSlim(3, 3);
_staticDataRoot = @"E:\yard\tmp\az-cli-out\az";
_placeholders = new(capacity: data.PlaceholderSet.Count);
_placeholderMap = new(capacity: data.PlaceholderSet.Count);

Expand Down Expand Up @@ -453,31 +457,23 @@ private ArgumentInfo CreateArgInfo(ArgumentPair pair)
private List<string> GetArgValues(ArgumentPair pair)
{
// First, try to get static argument values if they exist.
bool hasCompleter = true;
string command = pair.Command;
if (!s_azStaticDataCache.TryGetValue(command, out Command commandData))

AzCLICommand commandData = s_azStaticDataCache.GetOrAdd(command, QueryForMetadata);
AzCLIParameter param = commandData?.FindParameter(pair.Parameter);

if (param is not null)
{
string[] cmdElements = command.Split(' ', StringSplitOptions.RemoveEmptyEntries);
string dirPath = _staticDataRoot;
for (int i = 1; i < cmdElements.Length - 1; i++)
if (param.Choices?.Count > 0)
{
dirPath = Path.Combine(dirPath, cmdElements[i]);
return param.Choices;
}

string filePath = Path.Combine(dirPath, cmdElements[^1] + ".json");
commandData = File.Exists(filePath)
? JsonSerializer.Deserialize<Command>(File.OpenRead(filePath))
: null;
s_azStaticDataCache.TryAdd(command, commandData);
}

Option option = commandData?.FindOption(pair.Parameter);
List<string> staticValues = option?.Arguments;
if (staticValues?.Count > 0)
{
return staticValues;
hasCompleter = param.HasCompleter;
}

if (_stop) { return null; }
if (_stop || !hasCompleter) { return null; }

// Then, try to get dynamic argument values using AzCLI tab completion.
string commandLine = $"{pair.Command} {pair.Parameter} ";
Expand Down Expand Up @@ -551,6 +547,42 @@ private List<string> GetArgValues(ArgumentPair pair)
}
}

private AzCLICommand QueryForMetadata(string azCommand)
{
AzCLICommand command = null;
var reqBody = new StringContent(string.Format(MetadataQueryTemplate, azCommand), Encoding.UTF8, Utils.JsonContentType);
var request = new HttpRequestMessage(HttpMethod.Get, MetadataEndpoint) { Content = reqBody };

try
{
using var cts = new CancellationTokenSource(1200);
var response = _httpClient.Send(request, HttpCompletionOption.ResponseHeadersRead, cts.Token);

if (response.IsSuccessStatusCode)
{
using Stream stream = response.Content.ReadAsStream(cts.Token);
using JsonDocument document = JsonDocument.Parse(stream);

JsonElement root = document.RootElement;
if (root.TryGetProperty("data", out JsonElement data) &&
data.TryGetProperty("metadata", out JsonElement metadata))
{
command = metadata.Deserialize<AzCLICommand>(Utils.JsonOptions);
}
}
else
{
// TODO: telemetry.
}
}
catch (Exception)
{
// TODO: telemetry.
}

return command;
}

internal (string command, string parameter) GetMappedCommand(string placeholderName)
{
if (_placeholderMap.TryGetValue(placeholderName, out ArgumentPair pair))
Expand Down Expand Up @@ -585,6 +617,22 @@ public void Dispose()
_rootTask.Wait();
_semaphore.Dispose();
}

internal static void WarmUpMetadataService(HttpClient httpClient)
{
// Send a request to the AzCLI metadata service to warm up the service (code start is slow).
// We query for the command 'az sql server list' which only has 2 parameters,
// so it should cause minimum processing on the server side.
HttpRequestMessage request = new(HttpMethod.Get, MetadataEndpoint)
{
Content = new StringContent(
"{\"command\":\"az sql server list\"}",
Encoding.UTF8,
Utils.JsonContentType)
};

_ = httpClient.SendAsync(request, HttpCompletionOption.ResponseHeadersRead);
}
}

internal class ArgumentPair
Expand Down Expand Up @@ -703,83 +751,3 @@ internal bool TryMatchName(string name, out string prodName, out string envName)
return false;
}
}

public class Option
{
public string Name { get; }
public string[] Alias { get; }
public string[] Short { get; }
public string Attribute { get; }
public string Description { get; set; }
public List<string> Arguments { get; set; }

public Option(string name, string description, string[] alias, string[] @short, string attribute, List<string> arguments)
{
ArgumentException.ThrowIfNullOrEmpty(name);
ArgumentException.ThrowIfNullOrEmpty(description);

Name = name;
Alias = alias;
Short = @short;
Attribute = attribute;
Description = description;
Arguments = arguments;
}
}

public sealed class Command
{
public List<Option> Options { get; }
public string Examples { get; }
public string Name { get; }
public string Description { get; }

public Command(string name, string description, List<Option> options, string examples)
{
ArgumentException.ThrowIfNullOrEmpty(name);
ArgumentException.ThrowIfNullOrEmpty(description);
ArgumentNullException.ThrowIfNull(options);

Options = options;
Examples = examples;
Name = name;
Description = description;
}

public Option FindOption(string name)
{
foreach (Option option in Options)
{
if (name.StartsWith("--"))
{
if (string.Equals(option.Name, name, StringComparison.OrdinalIgnoreCase))
{
return option;
}

if (option.Alias is not null)
{
foreach (string alias in option.Alias)
{
if (string.Equals(alias, name, StringComparison.OrdinalIgnoreCase))
{
return option;
}
}
}
}
else if (option.Short is not null)
{
foreach (string s in option.Short)
{
if (string.Equals(s, name, StringComparison.OrdinalIgnoreCase))
{
return option;
}
}
}
}

return null;
}
}
35 changes: 33 additions & 2 deletions shell/agents/Microsoft.Azure.Agent/Schema.cs
Original file line number Diff line number Diff line change
Expand Up @@ -224,19 +224,50 @@ internal class ResponseData

internal class ArgumentPlaceholder
{
internal ArgumentPlaceholder(string query, ResponseData data)
internal ArgumentPlaceholder(string query, ResponseData data, HttpClient httpClient)
{
ArgumentException.ThrowIfNullOrEmpty(query);
ArgumentNullException.ThrowIfNull(data);

Query = query;
ResponseData = data;
DataRetriever = new(data);
DataRetriever = new(data, httpClient);
}

public string Query { get; set; }
public ResponseData ResponseData { get; set; }
public DataRetriever DataRetriever { get; }
}

internal class AzCLIParameter
{
public List<string> Options { get; set; }
public List<string> Choices { get; set; }
public bool Required { get; set; }

[JsonPropertyName("has_completer")]
public bool HasCompleter { get; set; }
}

internal class AzCLICommand
{
public List<AzCLIParameter> Parameters { get; set; }

public AzCLIParameter FindParameter(string name)
{
foreach (var param in Parameters)
{
foreach (var option in param.Options)
{
if (option.Equals(name, StringComparison.OrdinalIgnoreCase))
{
return param;
}
}
}

return null;
}
}

#endregion
2 changes: 2 additions & 0 deletions shell/agents/Microsoft.Azure.Agent/Utils.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ namespace Microsoft.Azure.Agent;

internal static class Utils
{
internal const string JsonContentType = "application/json";

private static readonly JsonSerializerOptions s_jsonOptions;
private static readonly JsonSerializerOptions s_humanReadableOptions;

Expand Down