From 59864edd2a744f4c79af3705b31ea5961a39b382 Mon Sep 17 00:00:00 2001 From: Dongbo Wang Date: Wed, 30 Oct 2024 16:36:09 -0700 Subject: [PATCH] Add the check for Copilot access --- .../Microsoft.Azure.Agent/AzureAgent.cs | 28 ++- .../Microsoft.Azure.Agent/ChatSession.cs | 211 ++++++++---------- shell/agents/Microsoft.Azure.Agent/Schema.cs | 148 ++++++++++++ shell/agents/Microsoft.Azure.Agent/Utils.cs | 16 ++ 4 files changed, 285 insertions(+), 118 deletions(-) diff --git a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs index ce694385..4a8f3cf6 100644 --- a/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs +++ b/shell/agents/Microsoft.Azure.Agent/AzureAgent.cs @@ -165,10 +165,26 @@ public async Task RefreshChatAsync(IShell shell, bool force) { host.WriteErrorLine("Operation cancelled. Please run '/refresh' to start a new conversation."); } - catch (CredentialUnavailableException) + catch (TokenRequestException e) { - host.WriteErrorLine($"Failed to start a chat session: Access token not available."); - host.WriteErrorLine($"The '{Name}' agent depends on the Azure CLI credential to acquire access token. Please run 'az login' from a command-line shell to setup account."); + if (e.UserUnauthorized) + { + host.WriteLine("Sorry, you are not authorized to access Azure Copilot services."); + host.WriteLine($"Details: {e.Message}"); + return; + } + + Exception inner = e.InnerException; + if (inner is CredentialUnavailableException) + { + host.WriteErrorLine($"Failed to start a chat session: Access token not available."); + host.WriteErrorLine($"The '{Name}' agent depends on the Azure CLI credential to acquire access token. Please run 'az login' from a command-line shell to setup account."); + host.WriteErrorLine("Once you've successfully logged in, please run '/refresh' to start a new conversation"); + return; + } + + host.WriteErrorLine(e.Message); + host.WriteErrorLine("Please try '/refresh' to start a new conversation."); } catch (Exception e) { @@ -189,6 +205,12 @@ public async Task ChatAsync(string input, IShell shell) return true; } + if (!_chatSession.UserAuthorized) + { + host.WriteLine("\nSorry, you are not authorized to access Azure Copilot services.\n"); + return true; + } + try { string query = $"{input}\n\n---\n\n{_instructions}"; diff --git a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs index 4934a884..93a550f8 100644 --- a/shell/agents/Microsoft.Azure.Agent/ChatSession.cs +++ b/shell/agents/Microsoft.Azure.Agent/ChatSession.cs @@ -5,31 +5,32 @@ using System.Text.Json.Nodes; using AIShell.Abstraction; -using Azure.Core; -using Azure.Identity; using Serilog; namespace Microsoft.Azure.Agent; internal class ChatSession : IDisposable { + private const string ACCESS_URL = "https://copilotweb.production.portalrp.azure.com/api/access?api-version=2024-09-01"; private const string DL_TOKEN_URL = "https://copilotweb.production.portalrp.azure.com/api/conversations/start?api-version=2024-11-15"; - private const string REFRESH_TOKEN_URL = "https://directline.botframework.com/v3/directline/tokens/refresh"; private const string CONVERSATION_URL = "https://directline.botframework.com/v3/directline/conversations"; - private string _token; + internal bool UserAuthorized { get; private set; } + private string _streamUrl; private string _conversationId; private string _conversationUrl; - private DateTime _expireOn; + private UserDirectLineToken _directLineToken; private AzureCopilotReceiver _copilotReceiver; private readonly HttpClient _httpClient; + private readonly UserAccessToken _accessToken; private readonly Dictionary _flights; internal ChatSession(HttpClient httpClient) { _httpClient = httpClient; + _accessToken = new UserAccessToken(); // Keys and values for flights are from the portal request. _flights = new Dictionary() @@ -63,23 +64,31 @@ internal ChatSession(HttpClient httpClient) internal async Task RefreshAsync(IStatusContext context, bool force, CancellationToken cancellationToken) { - if (_token is not null) + if (_directLineToken is not null) { if (force) { // End the existing conversation. context.Status("Ending current chat ..."); - EndConversation(); - Reset(); + EndCurrentConversation(); } else { try { - context.Status("Refreshing token ..."); - await RenewTokenAsync(cancellationToken); + context.Status("Refreshing access token ..."); + await _accessToken.CreateOrRenewTokenAsync(cancellationToken); + + context.Status("Refreshing DirectLine token ..."); + await _directLineToken.RenewTokenAsync(_httpClient, cancellationToken); + + // Tokens successfully refreshed. return null; } + catch (OperationCanceledException) + { + throw; + } catch (Exception) { // Refreshing failed. We will create a new chat session. @@ -87,50 +96,43 @@ internal async Task RefreshAsync(IStatusContext context, bool force, Can } } - _token = await GenerateTokenAsync(context, cancellationToken); - return await OpenConversationAsync(context, cancellationToken); + return await SetupNewChat(context, cancellationToken); } private void Reset() { - _token = null; _streamUrl = null; _conversationId = null; _conversationUrl = null; - _expireOn = DateTime.MinValue; + _directLineToken = null; + _accessToken.Reset(); _copilotReceiver?.Dispose(); _copilotReceiver = null; } - private async Task GenerateTokenAsync(IStatusContext context, CancellationToken cancellationToken) + private async Task SetupNewChat(IStatusContext context, CancellationToken cancellationToken) { try { context.Status("Get Azure CLI login token ..."); // Get an access token from the AzCLI login, using the specific audience guid. - AccessToken accessToken = await new AzureCliCredential() - .GetTokenAsync( - new TokenRequestContext(["7000789f-b583-4714-ab18-aef39213018a/.default"]), - cancellationToken); + await _accessToken.CreateOrRenewTokenAsync(cancellationToken); - context.Status("Request for DirectLine token ..."); - StringContent content = new("{\"conversationType\": \"Chat\"}", Encoding.UTF8, Utils.JsonContentType); - HttpRequestMessage request = new(HttpMethod.Post, DL_TOKEN_URL) { Content = content }; - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", accessToken.Token); + context.Status("Check Copilot authorization ..."); + await CheckAuthorizationAsync(cancellationToken); - HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); - response.EnsureSuccessStatusCode(); + context.Status("Request for DirectLine token ..."); + await GetInitialDLTokenAsync(cancellationToken); - using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken); - var dlToken = JsonSerializer.Deserialize(stream, Utils.JsonOptions); - return dlToken.DirectLine.Token; + context.Status("Start a new chat session ..."); + return await OpenConversationAsync(cancellationToken); } catch (Exception e) { - if (e is not OperationCanceledException) + if (e is not OperationCanceledException and TokenRequestException) { - Telemetry.Trace(AzTrace.Exception("Failed to generate the initial DL token."), e); + Telemetry.Trace(AzTrace.Exception("Failed to setup a new chat session."), e); } Reset(); @@ -138,104 +140,70 @@ private async Task GenerateTokenAsync(IStatusContext context, Cancellati } } - private async Task OpenConversationAsync(IStatusContext context, CancellationToken cancellationToken) + private async Task CheckAuthorizationAsync(CancellationToken cancellationToken) { - try - { - context.Status("Start a new chat session ..."); - HttpRequestMessage request = new(HttpMethod.Post, CONVERSATION_URL); - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); + HttpRequestMessage request = new(HttpMethod.Get, ACCESS_URL); + request.Headers.Accept.Add(new MediaTypeWithQualityHeaderValue("application/json")); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken.Token); - HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); - response.EnsureSuccessStatusCode(); - - using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken); - SessionPayload spl = JsonSerializer.Deserialize(content, Utils.JsonOptions); - - _token = spl.Token; - _conversationId = spl.ConversationId; - _conversationUrl = $"{CONVERSATION_URL}/{_conversationId}/activities"; - _streamUrl = spl.StreamUrl; - _expireOn = DateTime.UtcNow.AddSeconds(spl.ExpiresIn); - _copilotReceiver = await AzureCopilotReceiver.CreateAsync(_streamUrl); + HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + await response.EnsureSuccessStatusCodeForTokenRequest("Failed to check Copilot authorization."); - Log.Debug("[ChatSession] Conversation started. Id: {0}", _conversationId); + using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken); + var permission = JsonSerializer.Deserialize(stream, Utils.JsonOptions); + UserAuthorized = permission.Authorized; - while (true) - { - CopilotActivity activity = _copilotReceiver.Take(cancellationToken); - if (activity.IsMessage && activity.IsFromCopilot && _copilotReceiver.Watermark is 0) - { - activity.ExtractMetadata(out _, out ConversationState conversationState); - int chatNumber = conversationState.DailyConversationNumber; - int requestNumber = conversationState.TurnNumber; - return $"{activity.Text}\nThis is chat #{chatNumber}, request #{requestNumber}.\n"; - } - } - } - catch (Exception e) + if (!UserAuthorized) { - if (e is not OperationCanceledException) - { - Telemetry.Trace(AzTrace.Exception("Failed to open conversation with the initial DL token."), e); - } - - Reset(); - throw; + string message = $"Access token not authorized to access Azure Copilot. {permission.Message}"; + Telemetry.Trace(AzTrace.Exception(message)); + throw new TokenRequestException(message) { UserUnauthorized = true }; } } - private TokenHealth CheckDLTokenHealth() + private async Task GetInitialDLTokenAsync(CancellationToken cancellationToken) { - ArgumentNullException.ThrowIfNull(_token, nameof(_token)); + StringContent content = new("{\"conversationType\": \"Chat\"}", Encoding.UTF8, Utils.JsonContentType); + HttpRequestMessage request = new(HttpMethod.Post, DL_TOKEN_URL) { Content = content }; + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _accessToken.Token); - var now = DateTime.UtcNow; - if (now > _expireOn || now.AddMinutes(2) >= _expireOn) - { - return TokenHealth.Expired; - } - - if (now.AddMinutes(10) < _expireOn) - { - return TokenHealth.Good; - } + HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + await response.EnsureSuccessStatusCodeForTokenRequest("Failed to generate the initial DL token."); - return TokenHealth.TimeToRefresh; + using Stream stream = await response.Content.ReadAsStreamAsync(cancellationToken); + var dlToken = JsonSerializer.Deserialize(stream, Utils.JsonOptions); + _directLineToken = new UserDirectLineToken(dlToken.DirectLine.Token, dlToken.DirectLine.TokenExpiryTimeInSeconds); } - private async Task RenewTokenAsync(CancellationToken cancellationToken) + private async Task OpenConversationAsync(CancellationToken cancellationToken) { - TokenHealth health = CheckDLTokenHealth(); - if (health is TokenHealth.Expired) - { - Reset(); - throw new TokenRequestException("The chat session has expired. Please start a new chat session."); - } + HttpRequestMessage request = new(HttpMethod.Post, CONVERSATION_URL); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token); - if (health is TokenHealth.Good) - { - return; - } + HttpResponseMessage response = await _httpClient.SendAsync(request, cancellationToken); + await response.EnsureSuccessStatusCodeForTokenRequest("Failed to open an conversation."); - try - { - HttpRequestMessage request = new(HttpMethod.Post, REFRESH_TOKEN_URL); - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); + using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken); + SessionPayload spl = JsonSerializer.Deserialize(content, Utils.JsonOptions); - var response = await _httpClient.SendAsync(request, cancellationToken); - response.EnsureSuccessStatusCode(); + _conversationId = spl.ConversationId; + _conversationUrl = $"{CONVERSATION_URL}/{_conversationId}/activities"; + _directLineToken = new UserDirectLineToken(spl.Token, spl.ExpiresIn); + _streamUrl = spl.StreamUrl; + _copilotReceiver = await AzureCopilotReceiver.CreateAsync(_streamUrl); - using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken); - RefreshDLToken dlToken = JsonSerializer.Deserialize(content, Utils.JsonOptions); + Log.Debug("[ChatSession] Conversation started. Id: {0}", _conversationId); - _token = dlToken.Token; - _expireOn = DateTime.UtcNow.AddSeconds(dlToken.ExpiresIn); - } - catch (Exception e) when (e is not OperationCanceledException) + while (true) { - Reset(); - Telemetry.Trace(AzTrace.Exception("Failed to refresh the DL token."), e); - throw new TokenRequestException($"Failed to refresh the 'DirectLine' token: {e.Message}.", e); + CopilotActivity activity = _copilotReceiver.Take(cancellationToken); + if (activity.IsMessage && activity.IsFromCopilot && _copilotReceiver.Watermark is 0) + { + activity.ExtractMetadata(out _, out ConversationState conversationState); + int chatNumber = conversationState.DailyConversationNumber; + int requestNumber = conversationState.TurnNumber; + return $"{activity.Text}\nThis is chat #{chatNumber}, request #{requestNumber}.\n"; + } } } @@ -280,7 +248,7 @@ private HttpRequestMessage PrepareForChat(string input) var content = new StringContent(json, Encoding.UTF8, Utils.JsonContentType); var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content }; - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token); // This header is for server side telemetry to identify where the request comes from. request.Headers.Add("ClientType", "AIShell"); return request; @@ -299,9 +267,9 @@ private async Task SendQueryToCopilot(string input, CancellationToken ca return contentObj["id"].ToString(); } - private void EndConversation() + private void EndCurrentConversation() { - if (_token is null || CheckDLTokenHealth() is TokenHealth.Expired) + if (_directLineToken is null || _directLineToken.CheckTokenHealth() is TokenHealth.Expired) { // Chat session already expired, no need to send request to end the conversation. return; @@ -310,16 +278,24 @@ private void EndConversation() var content = new StringContent("{\"type\":\"endOfConversation\",\"from\":{\"id\":\"user\"}}", Encoding.UTF8, Utils.JsonContentType); var request = new HttpRequestMessage(HttpMethod.Post, _conversationUrl) { Content = content }; - request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _directLineToken.Token); _httpClient.Send(request, HttpCompletionOption.ResponseHeadersRead, CancellationToken.None); } internal async Task GetChatResponseAsync(string input, IStatusContext context, CancellationToken cancellationToken) { + if (_directLineToken is null) + { + throw new TokenRequestException("A chat session hasn't been setup yet."); + } + try { - context?.Status("Refreshing Token ..."); - await RenewTokenAsync(cancellationToken); + context.Status("Refreshing access token ..."); + await _accessToken.CreateOrRenewTokenAsync(cancellationToken); + + context.Status("Refreshing DirectLine token ..."); + await _directLineToken.RenewTokenAsync(_httpClient, cancellationToken); context?.Status("Sending query ..."); string activityId = await SendQueryToCopilot(input, cancellationToken); @@ -370,11 +346,16 @@ internal async Task GetChatResponseAsync(string input, IStatusC // TODO: we may need to notify azure copilot somehow about the cancellation. return null; } + catch (TokenRequestException) + { + Reset(); + throw; + } } public void Dispose() { - EndConversation(); + EndCurrentConversation(); _copilotReceiver?.Dispose(); } } diff --git a/shell/agents/Microsoft.Azure.Agent/Schema.cs b/shell/agents/Microsoft.Azure.Agent/Schema.cs index a88b08c0..2575e815 100644 --- a/shell/agents/Microsoft.Azure.Agent/Schema.cs +++ b/shell/agents/Microsoft.Azure.Agent/Schema.cs @@ -1,8 +1,12 @@ +using System.Net.Http.Headers; using System.Text; using System.Text.Json; using System.Text.Json.Nodes; using System.Text.Json.Serialization; + using AIShell.Abstraction; +using Azure.Core; +using Azure.Identity; namespace Microsoft.Azure.Agent; @@ -58,6 +62,12 @@ internal enum TokenHealth Expired } +internal class CopilotPermission +{ + public bool Authorized { get; set; } + public string Message { get; set; } +} + internal class RefreshDLToken { public string ConversationId { get; set; } @@ -340,3 +350,141 @@ public AzCLIParameter FindParameter(string name) } #endregion + +#region token wrappers + +internal class UserAccessToken +{ + private readonly TokenRequestContext _tokenContext; + private AccessToken? _accessToken; + + /// + /// The access token. + /// + internal string Token => _accessToken?.Token; + + /// + /// Initialize an instance with the proper token request context. + /// + internal UserAccessToken() + { + _tokenContext = new TokenRequestContext(["7000789f-b583-4714-ab18-aef39213018a/.default"]); + } + + /// + /// Create an access token, or renew an existing token. + /// + internal async Task CreateOrRenewTokenAsync(CancellationToken cancellationToken) + { + try + { + bool needRefresh = !_accessToken.HasValue; + if (!needRefresh) + { + needRefresh = DateTimeOffset.UtcNow.AddMinutes(5) > _accessToken.Value.ExpiresOn; + } + + if (needRefresh) + { + _accessToken = await new AzureCliCredential() + .GetTokenAsync(_tokenContext, cancellationToken); + } + } + catch (Exception e) when (e is not OperationCanceledException) + { + string message = $"Failed to generate the user access token: {e.Message}."; + Telemetry.Trace(AzTrace.Exception(message), e); + throw new TokenRequestException(message, e); + } + } + + /// + /// Reset the access token. + /// + internal void Reset() + { + _accessToken = null; + } +} + +internal class UserDirectLineToken +{ + private const string REFRESH_TOKEN_URL = "https://directline.botframework.com/v3/directline/tokens/refresh"; + + private string _token; + private DateTimeOffset _expireOn; + + /// + /// The DirectLine token. + /// + internal string Token => _token; + + /// + /// Initialize an instance. + /// + internal UserDirectLineToken(string token, int expiresInSec) + { + _token = token; + _expireOn = DateTimeOffset.UtcNow.AddSeconds(expiresInSec); + } + + /// + /// Check the token health. + /// + /// + internal TokenHealth CheckTokenHealth() + { + var now = DateTimeOffset.UtcNow; + if (now > _expireOn || now.AddMinutes(2) >= _expireOn) + { + return TokenHealth.Expired; + } + + if (now.AddMinutes(10) < _expireOn) + { + return TokenHealth.Good; + } + + return TokenHealth.TimeToRefresh; + } + + /// + /// Renew the DirectLine token. + /// + internal async Task RenewTokenAsync(HttpClient httpClient, CancellationToken cancellationToken) + { + TokenHealth health = CheckTokenHealth(); + if (health is TokenHealth.Expired) + { + throw new TokenRequestException("The chat session has expired. Please start a new chat session."); + } + + if (health is TokenHealth.Good) + { + return; + } + + try + { + HttpRequestMessage request = new(HttpMethod.Post, REFRESH_TOKEN_URL); + request.Headers.Authorization = new AuthenticationHeaderValue("Bearer", _token); + + var response = await httpClient.SendAsync(request, cancellationToken); + response.EnsureSuccessStatusCode(); + + using Stream content = await response.Content.ReadAsStreamAsync(cancellationToken); + RefreshDLToken dlToken = JsonSerializer.Deserialize(content, Utils.JsonOptions); + + _token = dlToken.Token; + _expireOn = DateTimeOffset.UtcNow.AddSeconds(dlToken.ExpiresIn); + } + catch (Exception e) when (e is not OperationCanceledException) + { + string message = $"Failed to renew the 'DirectLine' token: {e.Message}."; + Telemetry.Trace(AzTrace.Exception(message), e); + throw new TokenRequestException(message, e); + } + } +} + +#endregion diff --git a/shell/agents/Microsoft.Azure.Agent/Utils.cs b/shell/agents/Microsoft.Azure.Agent/Utils.cs index 5d216358..d14ceb7a 100644 --- a/shell/agents/Microsoft.Azure.Agent/Utils.cs +++ b/shell/agents/Microsoft.Azure.Agent/Utils.cs @@ -34,10 +34,26 @@ static Utils() internal static JsonSerializerOptions JsonOptions => s_jsonOptions; internal static JsonSerializerOptions JsonHumanReadableOptions => s_humanReadableOptions; internal static JsonSerializerOptions RelaxedJsonEscapingOptions => s_relaxedJsonEscapingOptions; + + internal async static Task EnsureSuccessStatusCodeForTokenRequest(this HttpResponseMessage response, string errorMessage) + { + if (!response.IsSuccessStatusCode) + { + string responseText = await response.Content.ReadAsStringAsync(CancellationToken.None); + string message = $"{errorMessage} HTTP status: {response.StatusCode}, Response: {responseText}"; + Telemetry.Trace(AzTrace.Exception(message)); + throw new TokenRequestException(message); + } + } } internal class TokenRequestException : Exception { + /// + /// Access to Copilot was denied. + /// + internal bool UserUnauthorized { get; set; } + internal TokenRequestException(string message) : base(message) {