Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 25 additions & 3 deletions shell/agents/Microsoft.Azure.Agent/AzureAgent.cs
Original file line number Diff line number Diff line change
Expand Up @@ -165,10 +165,26 @@ public async Task RefreshChatAsync(IShell shell, bool force)
{
host.WriteErrorLine("Operation cancelled. Please run '/refresh' to start a new conversation.");
}
catch (CredentialUnavailableException)
catch (TokenRequestException e)
{
host.WriteErrorLine($"Failed to start a chat session: Access token not available.");
host.WriteErrorLine($"The '{Name}' agent depends on the Azure CLI credential to acquire access token. Please run 'az login' from a command-line shell to setup account.");
if (e.UserUnauthorized)
{
host.WriteLine("Sorry, you are not authorized to access Azure Copilot services.");
host.WriteLine($"Details: {e.Message}");
return;
}

Exception inner = e.InnerException;
if (inner is CredentialUnavailableException)
{
host.WriteErrorLine($"Failed to start a chat session: Access token not available.");
host.WriteErrorLine($"The '{Name}' agent depends on the Azure CLI credential to acquire access token. Please run 'az login' from a command-line shell to setup account.");
host.WriteErrorLine("Once you've successfully logged in, please run '/refresh' to start a new conversation");
return;
}

host.WriteErrorLine(e.Message);
host.WriteErrorLine("Please try '/refresh' to start a new conversation.");
}
catch (Exception e)
{
Expand All @@ -189,6 +205,12 @@ public async Task<bool> ChatAsync(string input, IShell shell)
return true;
}

if (!_chatSession.UserAuthorized)
{
host.WriteLine("\nSorry, you are not authorized to access Azure Copilot services.\n");
return true;
}

try
{
string query = $"{input}\n\n---\n\n{_instructions}";
Expand Down
211 changes: 96 additions & 115 deletions shell/agents/Microsoft.Azure.Agent/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,32 @@
using System.Text.Json.Nodes;

using AIShell.Abstraction;
using Azure.Core;
using Azure.Identity;
using Serilog;

namespace Microsoft.Azure.Agent;

internal class ChatSession : IDisposable
{
private const string ACCESS_URL = "https://copilotweb.production.portalrp.azure.com/api/access?api-version=2024-09-01";
private const string DL_TOKEN_URL = "https://copilotweb.production.portalrp.azure.com/api/conversations/start?api-version=2024-11-15";
private const string REFRESH_TOKEN_URL = "https://directline.botframework.com/v3/directline/tokens/refresh";
private const string CONVERSATION_URL = "https://directline.botframework.com/v3/directline/conversations";

private string _token;
internal bool UserAuthorized { get; private set; }

private string _streamUrl;
private string _conversationId;
private string _conversationUrl;
private DateTime _expireOn;
private UserDirectLineToken _directLineToken;
private AzureCopilotReceiver _copilotReceiver;

private readonly HttpClient _httpClient;
private readonly UserAccessToken _accessToken;
private readonly Dictionary<string, object> _flights;

internal ChatSession(HttpClient httpClient)
{
_httpClient = httpClient;
_accessToken = new UserAccessToken();

// Keys and values for flights are from the portal request.
_flights = new Dictionary<string, object>()
Expand Down Expand Up @@ -63,179 +64,146 @@ internal ChatSession(HttpClient httpClient)

internal async Task<string> RefreshAsync(IStatusContext context, bool force, CancellationToken cancellationToken)
{
if (_token is not null)
if (_directLineToken is not null)
{
if (force)
{
// End the existing conversation.
context.Status("Ending current chat ...");
EndConversation();
Reset();
EndCurrentConversation();
}
else
{
try
{
context.Status("Refreshing token ...");
await RenewTokenAsync(cancellationToken);
context.Status("Refreshing access token ...");
await _accessToken.CreateOrRenewTokenAsync(cancellationToken);

context.Status("Refreshing DirectLine token ...");
await _directLineToken.RenewTokenAsync(_httpClient, cancellationToken);

// Tokens successfully refreshed.
return null;
}
catch (OperationCanceledException)
{
throw;
}
catch (Exception)
{
// Refreshing failed. We will create a new chat session.
}
}
}

_token = await GenerateTokenAsync(context, cancellationToken);
return await OpenConversationAsync(context, cancellationToken);
return await SetupNewChat(context, cancellationToken);
}

private void Reset()
{
_token = null;
_streamUrl = null;
_conversationId = null;
_conversationUrl = null;
_expireOn = DateTime.MinValue;
_directLineToken = null;

_accessToken.Reset();
_copilotReceiver?.Dispose();
_copilotReceiver = null;
}

private async Task<string> GenerateTokenAsync(IStatusContext context, CancellationToken cancellationToken)
private async Task<string> SetupNewChat(IStatusContext context, CancellationToken cancellationToken)
{
try
{
context.Status("Get Azure CLI login token ...");
// Get an access token from the AzCLI login, using the specific audience guid.
AccessToken accessToken = await new AzureCliCredential()
.GetTokenAsync(
new TokenRequestContext(["7000789f-b583-4714-ab18-aef39213018a/.default"]),
cancellationToken);
await _accessToken.CreateOrRenewTokenAsync(cancellationToken);

context.Status("Request for DirectLine token ...");
StringContent content = new("{\"conversationType\": \"Chat\"}", Encoding.UTF8, Utils.JsonContentType);
HttpRequestMessage request = new(HttpMethod.Post, DL_TOKEN_URL) { Content = content };
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken.Token);
context.Status("Check Copilot authorization ...");
await CheckAuthorizationAsync(cancellationToken);

HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();
context.Status("Request for DirectLine token ...");
await GetInitialDLTokenAsync(cancellationToken);

using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken);
var dlToken = JsonSerializer.Deserialize<DirectLineToken>(stream, Utils.JsonOptions);
return dlToken.DirectLine.Token;
context.Status("Start a new chat session ...");
return await OpenConversationAsync(cancellationToken);
}
catch (Exception e)
{
if (e is not OperationCanceledException)
if (e is not OperationCanceledException and TokenRequestException)
{
Telemetry.Trace(AzTrace.Exception("Failed to generate the initial DL token."), e);
Telemetry.Trace(AzTrace.Exception("Failed to setup a new chat session."), e);
}

Reset();
throw;
}
}

private async Task<string> OpenConversationAsync(IStatusContext context, CancellationToken cancellationToken)
private async Task CheckAuthorizationAsync(CancellationToken cancellationToken)
{
try
{
context.Status("Start a new chat session ...");
HttpRequestMessage request = new(HttpMethod.Post, CONVERSATION_URL);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token);
HttpRequestMessage request = new(HttpMethod.Get, ACCESS_URL);
request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json"));
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken.Token);

HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();

using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken);
SessionPayload spl = JsonSerializer.Deserialize<SessionPayload>(content, Utils.JsonOptions);

_token = spl.Token;
_conversationId = spl.ConversationId;
_conversationUrl = $"{CONVERSATION_URL}/{_conversationId}/activities";
_streamUrl = spl.StreamUrl;
_expireOn = DateTime.UtcNow.AddSeconds(spl.ExpiresIn);
_copilotReceiver = await AzureCopilotReceiver.CreateAsync(_streamUrl);
HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
await response.EnsureSuccessStatusCodeForTokenRequest("Failed to check Copilot authorization.");

Log.Debug("[ChatSession] Conversation started. Id: {0}", _conversationId);
using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken);
var permission = JsonSerializer.Deserialize<CopilotPermission>(stream, Utils.JsonOptions);
UserAuthorized = permission.Authorized;

while (true)
{
CopilotActivity activity = _copilotReceiver.Take(cancellationToken);
if (activity.IsMessage && activity.IsFromCopilot && _copilotReceiver.Watermark is 0)
{
activity.ExtractMetadata(out _, out ConversationState conversationState);
int chatNumber = conversationState.DailyConversationNumber;
int requestNumber = conversationState.TurnNumber;
return $"{activity.Text}\nThis is chat #{chatNumber}, request #{requestNumber}.\n";
}
}
}
catch (Exception e)
if (!UserAuthorized)
{
if (e is not OperationCanceledException)
{
Telemetry.Trace(AzTrace.Exception("Failed to open conversation with the initial DL token."), e);
}

Reset();
throw;
string message = $"Access token not authorized to access Azure Copilot. {permission.Message}";
Telemetry.Trace(AzTrace.Exception(message));
throw new TokenRequestException(message) { UserUnauthorized = true };
}
}

private TokenHealth CheckDLTokenHealth()
private async Task GetInitialDLTokenAsync(CancellationToken cancellationToken)
{
ArgumentNullException.ThrowIfNull(_token, nameof(_token));
StringContent content = new("{\"conversationType\": \"Chat\"}", Encoding.UTF8, Utils.JsonContentType);
HttpRequestMessage request = new(HttpMethod.Post, DL_TOKEN_URL) { Content = content };
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken.Token);

var now = DateTime.UtcNow;
if (now > _expireOn || now.AddMinutes(2) >= _expireOn)
{
return TokenHealth.Expired;
}

if (now.AddMinutes(10) < _expireOn)
{
return TokenHealth.Good;
}
HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
await response.EnsureSuccessStatusCodeForTokenRequest("Failed to generate the initial DL token.");

return TokenHealth.TimeToRefresh;
using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken);
var dlToken = JsonSerializer.Deserialize<DirectLineToken>(stream, Utils.JsonOptions);
_directLineToken = new UserDirectLineToken(dlToken.DirectLine.Token, dlToken.DirectLine.TokenExpiryTimeInSeconds);
}

private async Task RenewTokenAsync(CancellationToken cancellationToken)
private async Task<string> OpenConversationAsync(CancellationToken cancellationToken)
{
TokenHealth health = CheckDLTokenHealth();
if (health is TokenHealth.Expired)
{
Reset();
throw new TokenRequestException("The chat session has expired. Please start a new chat session.");
}
HttpRequestMessage request = new(HttpMethod.Post, CONVERSATION_URL);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token);

if (health is TokenHealth.Good)
{
return;
}
HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken);
await response.EnsureSuccessStatusCodeForTokenRequest("Failed to open an conversation.");

try
{
HttpRequestMessage request = new(HttpMethod.Post, REFRESH_TOKEN_URL);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token);
using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken);
SessionPayload spl = JsonSerializer.Deserialize<SessionPayload>(content, Utils.JsonOptions);

var response = await _httpClient.SendAsync(request, cancellationToken);
response.EnsureSuccessStatusCode();
_conversationId = spl.ConversationId;
_conversationUrl = $"{CONVERSATION_URL}/{_conversationId}/activities";
_directLineToken = new UserDirectLineToken(spl.Token, spl.ExpiresIn);
_streamUrl = spl.StreamUrl;
_copilotReceiver = await AzureCopilotReceiver.CreateAsync(_streamUrl);

using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken);
RefreshDLToken dlToken = JsonSerializer.Deserialize<RefreshDLToken>(content, Utils.JsonOptions);
Log.Debug("[ChatSession] Conversation started. Id: {0}", _conversationId);

_token = dlToken.Token;
_expireOn = DateTime.UtcNow.AddSeconds(dlToken.ExpiresIn);
}
catch (Exception e) when (e is not OperationCanceledException)
while (true)
{
Reset();
Telemetry.Trace(AzTrace.Exception("Failed to refresh the DL token."), e);
throw new TokenRequestException($"Failed to refresh the 'DirectLine' token: {e.Message}.", e);
CopilotActivity activity = _copilotReceiver.Take(cancellationToken);
if (activity.IsMessage && activity.IsFromCopilot && _copilotReceiver.Watermark is 0)
{
activity.ExtractMetadata(out _, out ConversationState conversationState);
int chatNumber = conversationState.DailyConversationNumber;
int requestNumber = conversationState.TurnNumber;
return $"{activity.Text}\nThis is chat #{chatNumber}, request #{requestNumber}.\n";
}
}
}

Expand Down Expand Up @@ -280,7 +248,7 @@ private HttpRequestMessage PrepareForChat(string input)
var content = new StringContent(json, Encoding.UTF8, Utils.JsonContentType);
var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content };

request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token);
// This header is for server side telemetry to identify where the request comes from.
request.Headers.Add("ClientType", "AIShell");
return request;
Expand All @@ -299,9 +267,9 @@ private async Task<string> SendQueryToCopilot(string input, CancellationToken ca
return contentObj["id"].ToString();
}

private void EndConversation()
private void EndCurrentConversation()
{
if (_token is null || CheckDLTokenHealth() is TokenHealth.Expired)
if (_directLineToken is null || _directLineToken.CheckTokenHealth() is TokenHealth.Expired)
{
// Chat session already expired, no need to send request to end the conversation.
return;
Expand All @@ -310,16 +278,24 @@ private void EndConversation()
var content = new StringContent("{\"type\":\"endOfConversation\",\"from\":{\"id\":\"user\"}}", Encoding.UTF8, Utils.JsonContentType);
var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content };

request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token);
request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token);
_httpClient.Send(request, HttpCompletionOption.ResponseHeadersRead, CancellationToken.None);
}

internal async Task<CopilotResponse> GetChatResponseAsync(string input, IStatusContext context, CancellationToken cancellationToken)
{
if (_directLineToken is null)
{
throw new TokenRequestException("A chat session hasn't been setup yet.");
}

try
{
context?.Status("Refreshing Token ...");
await RenewTokenAsync(cancellationToken);
context.Status("Refreshing access token ...");
await _accessToken.CreateOrRenewTokenAsync(cancellationToken);

context.Status("Refreshing DirectLine token ...");
await _directLineToken.RenewTokenAsync(_httpClient, cancellationToken);

context?.Status("Sending query ...");
string activityId = await SendQueryToCopilot(input, cancellationToken);
Expand Down Expand Up @@ -370,11 +346,16 @@ internal async Task<CopilotResponse> GetChatResponseAsync(string input, IStatusC
// TODO: we may need to notify azure copilot somehow about the cancellation.
return null;
}
catch (TokenRequestException)
{
Reset();
throw;
}
}

public void Dispose()
{
EndConversation();
EndCurrentConversation();
_copilotReceiver?.Dispose();
}
}
Loading