diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs index 79155b654..fb7236c72 100644 --- a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs +++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs @@ -7,21 +7,20 @@ using System.Text.Json.Serialization; using Elastic.Documentation.Api.Core.AskAi; using Elastic.Documentation.Api.Infrastructure.Gcp; -using Microsoft.Extensions.Options; namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi; -public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, IOptionsSnapshot options) : IAskAiGateway +public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway { public async Task AskAi(AskAiRequest askAiRequest, Cancel ctx = default) { var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest); var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest); - var request = new HttpRequestMessage(HttpMethod.Post, options.Value.FunctionUrl) + var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl) { Content = new StringContent(requestBody, Encoding.UTF8, "application/json") }; - var authToken = await tokenProvider.GenerateIdTokenAsync(ctx); + var authToken = await tokenProvider.GenerateIdTokenAsync(options.ServiceAccount, options.TargetAudience, ctx); request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", authToken); request.Headers.Add("User-Agent", "elastic-docs-proxy/1.0"); request.Headers.Add("Accept", "text/event-stream"); @@ -44,7 +43,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) => PlatformContext: new PlatformContext("support_portal", "support_assistant", []), Input: [ - new ChatInput("system", AskAiRequest.SystemPrompt), + new ChatInput("user", AskAiRequest.SystemPrompt), new ChatInput("user", request.Message) ], ThreadId: request.ThreadId ?? "elastic-docs-" + Guid.NewGuid() diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayOptions.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayOptions.cs new file mode 100644 index 000000000..b9f3ec905 --- /dev/null +++ b/src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayOptions.cs @@ -0,0 +1,22 @@ +// Licensed to Elasticsearch B.V under one or more agreements. +// Elasticsearch B.V licenses this file to you under the Apache 2.0 License. +// See the LICENSE file in the project root for more information + +using Elastic.Documentation.Api.Infrastructure.Aws; + +namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi; + +public class LlmGatewayOptions +{ + public LlmGatewayOptions(IParameterProvider parameterProvider) + { + ServiceAccount = parameterProvider.GetParam("llm-gateway-service-account").GetAwaiter().GetResult(); + FunctionUrl = parameterProvider.GetParam("llm-gateway-function-url").GetAwaiter().GetResult(); + var uri = new Uri(FunctionUrl); + TargetAudience = $"{uri.Scheme}://{uri.Host}"; + } + + public string ServiceAccount { get; } + public string FunctionUrl { get; } + public string TargetAudience { get; } +} diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs index 651e72e19..37bb8519f 100644 --- a/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs +++ b/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs @@ -9,7 +9,7 @@ namespace Elastic.Documentation.Api.Infrastructure.Aws; -public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, ILogger logger) : IParameterProvider +public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, AppEnvironment appEnvironment, ILogger logger) : IParameterProvider { public const string HttpClientName = "AwsParametersAndSecretsLambdaExtensionClient"; private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName); @@ -18,8 +18,10 @@ public async Task GetParam(string name, bool withDecryption = true, Canc { try { - logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", name); - var response = await _httpClient.GetFromJsonAsync($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(name)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx); + var prefix = $"/elastic-docs-v3/{appEnvironment.Current.ToStringFast(true)}/"; + var prefixedName = prefix + name.TrimStart('/'); + logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", prefixedName); + var response = await _httpClient.GetFromJsonAsync($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(prefixedName)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx); return response?.Parameter?.Value ?? throw new InvalidOperationException($"Parameter value for '{name}' is null."); } catch (HttpRequestException httpEx) @@ -42,23 +44,23 @@ public async Task GetParam(string name, bool withDecryption = true, Canc internal sealed class ParameterResponse { - public Parameter? Parameter { get; set; } + public required Parameter Parameter { get; set; } } internal sealed class Parameter { - public string? Arn { get; set; } - public string? Name { get; set; } - public string? Type { get; set; } - public string? Value { get; set; } - public string? Version { get; set; } + [JsonPropertyName("ARN")] + public required string Arn { get; set; } + public required string Name { get; set; } + public required string Type { get; set; } + public required string Value { get; set; } + public required int Version { get; set; } public string? Selector { get; set; } - public string? LastModifiedDate { get; set; } - public string? LastModifiedUser { get; set; } - public string? DataType { get; set; } + public DateTime LastModifiedDate { get; set; } + public required string DataType { get; set; } } [JsonSerializable(typeof(ParameterResponse))] -[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)] +[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.Unspecified)] internal sealed partial class AwsJsonContext : JsonSerializerContext; diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs index c8974fb63..66d817d61 100644 --- a/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs +++ b/src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs @@ -10,7 +10,7 @@ public async Task GetParam(string name, bool withDecryption = true, Canc { switch (name) { - case "/elastic-docs-v3/dev/llm-gateway-service-account": + case "llm-gateway-service-account": { const string envName = "LLM_GATEWAY_SERVICE_ACCOUNT_KEY_PATH"; var serviceAccountKeyPath = Environment.GetEnvironmentVariable(envName); @@ -21,7 +21,7 @@ public async Task GetParam(string name, bool withDecryption = true, Canc var serviceAccountKey = await File.ReadAllTextAsync(serviceAccountKeyPath, ctx); return serviceAccountKey; } - case "/elastic-docs-v3/dev/llm-gateway-function-url": + case "llm-gateway-function-url": { const string envName = "LLM_GATEWAY_FUNCTION_URL"; var value = Environment.GetEnvironmentVariable(envName); diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs b/src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs index 3c06e6345..c426a279d 100644 --- a/src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs +++ b/src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs @@ -2,37 +2,48 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System.Collections.Concurrent; using System.Security.Cryptography; using System.Text; using System.Text.Json; using System.Text.Json.Serialization; -using Microsoft.Extensions.Options; namespace Elastic.Documentation.Api.Infrastructure.Gcp; // This is a custom implementation to create an ID token for GCP. // Because Google.Api.Auth.OAuth2 is not compatible with AOT -public class GcpIdTokenProvider(HttpClient httpClient, IOptionsSnapshot options) +public class GcpIdTokenProvider(HttpClient httpClient) { - public async Task GenerateIdTokenAsync(Cancel cancellationToken = default) + // Cache tokens by target audience to avoid regenerating them on every request + private static readonly ConcurrentDictionary TokenCache = new(); + + private sealed record CachedToken(string Token, DateTimeOffset ExpiresAt); + + public async Task GenerateIdTokenAsync(string serviceAccount, string targetAudience, Cancel cancellationToken = default) { + // Check if we have a valid cached token + if (TokenCache.TryGetValue(targetAudience, out var cachedToken) && + cachedToken.ExpiresAt > DateTimeOffset.UtcNow.AddMinutes(1)) // Refresh 1 minute before expiry + return cachedToken.Token; + // Read and parse service account key file using System.Text.Json source generation (AOT compatible) - var serviceAccount = JsonSerializer.Deserialize(options.Value.ServiceAccount, GcpJsonContext.Default.ServiceAccountKey); + var serviceAccountJson = JsonSerializer.Deserialize(serviceAccount, GcpJsonContext.Default.ServiceAccountKey); // Create JWT header - var header = new JwtHeader("RS256", "JWT", serviceAccount.PrivateKeyId); + var header = new JwtHeader("RS256", "JWT", serviceAccountJson.PrivateKeyId); var headerJson = JsonSerializer.Serialize(header, JwtHeaderJsonContext.Default.JwtHeader); var headerBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson)); // Create JWT payload - var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds(); + var now = DateTimeOffset.UtcNow; + var expirationTime = now.AddHours(1); var payload = new JwtPayload( - serviceAccount.ClientEmail, - serviceAccount.ClientEmail, + serviceAccountJson.ClientEmail, + serviceAccountJson.ClientEmail, "https://oauth2.googleapis.com/token", - now, - now + 300, // 5 minutes - options.Value.TargetAudience + now.ToUnixTimeSeconds(), + expirationTime.ToUnixTimeSeconds(), + targetAudience ); var payloadJson = JsonSerializer.Serialize(payload, GcpJsonContext.Default.JwtPayload); @@ -43,7 +54,7 @@ public async Task GenerateIdTokenAsync(Cancel cancellationToken = defaul var messageBytes = Encoding.UTF8.GetBytes(message); // Parse the private key (removing PEM headers/footers and decoding) - var privateKeyPem = serviceAccount.PrivateKey + var privateKeyPem = serviceAccountJson.PrivateKey .Replace("-----BEGIN PRIVATE KEY-----", "") .Replace("-----END PRIVATE KEY-----", "") .Replace("\n", "") @@ -59,7 +70,14 @@ public async Task GenerateIdTokenAsync(Cancel cancellationToken = defaul var jwt = $"{message}.{signatureBase64}"; // Exchange JWT for ID token - return await ExchangeJwtForIdToken(jwt, options.Value.TargetAudience, cancellationToken); + var idToken = await ExchangeJwtForIdToken(jwt, targetAudience, cancellationToken); + + var expiresAt = expirationTime.Subtract(TimeSpan.FromMinutes(1)); + _ = TokenCache.AddOrUpdate(targetAudience, + new CachedToken(idToken, expiresAt), + (_, _) => new CachedToken(idToken, expiresAt)); + + return idToken; } diff --git a/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs b/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs index 3bb7ecf15..c8e37f029 100644 --- a/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs +++ b/src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs @@ -15,7 +15,7 @@ namespace Elastic.Documentation.Api.Infrastructure; [EnumExtensions] -public enum AppEnvironment +public enum AppEnv { [Display(Name = "dev")] Dev, [Display(Name = "staging")] Staging, @@ -23,11 +23,9 @@ public enum AppEnvironment [Display(Name = "prod")] Prod } -public class LlmGatewayOptions +public class AppEnvironment { - public string ServiceAccount { get; set; } = string.Empty; - public string FunctionUrl { get; set; } = string.Empty; - public string TargetAudience { get; set; } = string.Empty; + public AppEnv Current { get; init; } } public static class ServicesExtension @@ -41,7 +39,7 @@ public static class ServicesExtension public static void AddElasticDocsApiUsecases(this IServiceCollection services, string? appEnvironment) { - if (AppEnvironmentExtensions.TryParse(appEnvironment, out var parsedEnvironment, true)) + if (AppEnvExtensions.TryParse(appEnvironment, out var parsedEnvironment, true)) { AddElasticDocsApiUsecases(services, parsedEnvironment); } @@ -49,34 +47,36 @@ public static void AddElasticDocsApiUsecases(this IServiceCollection services, s { var logger = GetLogger(services); logger?.LogWarning("Unable to parse environment {AppEnvironment} into AppEnvironment. Using default AppEnvironment.Dev", appEnvironment); - AddElasticDocsApiUsecases(services, AppEnvironment.Dev); + AddElasticDocsApiUsecases(services, AppEnv.Dev); } } - private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnvironment appEnvironment) + private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnv appEnv) { _ = services.ConfigureHttpJsonOptions(options => { options.SerializerOptions.TypeInfoResolverChain.Insert(0, ApiJsonContext.Default); }); _ = services.AddHttpClient(); - AddParameterProvider(services, appEnvironment); - AddAskAiUsecase(services, appEnvironment); + // Register AppEnvironment as a singleton for dependency injection + _ = services.AddSingleton(new AppEnvironment { Current = appEnv }); + AddParameterProvider(services, appEnv); + AddAskAiUsecase(services, appEnv); } - // https://docs.aws.amazon.com/systems-manager/latest/userguide/ps-integration-lambda-extensions.html - private static void AddParameterProvider(IServiceCollection services, AppEnvironment appEnvironment) + // https://docs.aws.amazon.com/systems -manager/latest/userguide/ps-integration-lambda-extensions.html + private static void AddParameterProvider(IServiceCollection services, AppEnv appEnv) { var logger = GetLogger(services); - switch (appEnvironment) + switch (appEnv) { - case AppEnvironment.Prod: - case AppEnvironment.Staging: - case AppEnvironment.Edge: + case AppEnv.Prod: + case AppEnv.Staging: + case AppEnv.Edge: { - logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnvironment); + logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnv); _ = services.AddHttpClient(LambdaExtensionParameterProvider.HttpClientName, client => { client.BaseAddress = new Uri("http://localhost:2773"); @@ -85,39 +85,27 @@ private static void AddParameterProvider(IServiceCollection services, AppEnviron _ = services.AddSingleton(); break; } - case AppEnvironment.Dev: + case AppEnv.Dev: { - logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnvironment); + logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnv); _ = services.AddSingleton(); break; } default: { - throw new ArgumentOutOfRangeException(nameof(appEnvironment), appEnvironment, + throw new ArgumentOutOfRangeException(nameof(appEnv), appEnv, "Unsupported environment for parameter provider."); } } } - private static void AddAskAiUsecase(IServiceCollection services, AppEnvironment appEnvironment) + private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv) { var logger = GetLogger(services); - logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnvironment); - - _ = services.Configure(options => - { - var serviceProvider = services.BuildServiceProvider(); - var parameterProvider = serviceProvider.GetRequiredService(); - var appEnvString = appEnvironment.ToStringFast(true); - - options.ServiceAccount = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-service-account").GetAwaiter().GetResult(); - options.FunctionUrl = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-function-url").GetAwaiter().GetResult(); - - var functionUri = new Uri(options.FunctionUrl); - options.TargetAudience = $"{functionUri.Scheme}://{functionUri.Host}"; - }); - _ = services.AddScoped(); - _ = services.AddScoped, LlmGatewayAskAiGateway>(); + logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnv); + _ = services.AddSingleton(); + _ = services.AddSingleton, LlmGatewayAskAiGateway>(); + _ = services.AddScoped(); _ = services.AddScoped(); } } diff --git a/src/api/Elastic.Documentation.Api.Lambda/Elastic.Documentation.Api.Lambda.csproj b/src/api/Elastic.Documentation.Api.Lambda/Elastic.Documentation.Api.Lambda.csproj index 5d49a07a4..19d4574f6 100644 --- a/src/api/Elastic.Documentation.Api.Lambda/Elastic.Documentation.Api.Lambda.csproj +++ b/src/api/Elastic.Documentation.Api.Lambda/Elastic.Documentation.Api.Lambda.csproj @@ -13,12 +13,11 @@ true true true - true false Linux - true - $(InterceptorsPreviewNamespaces);Microsoft.AspNetCore.Http.Generated + + Elastic.Documentation.Api.Lambda diff --git a/src/api/Elastic.Documentation.Api.Lambda/Program.cs b/src/api/Elastic.Documentation.Api.Lambda/Program.cs index 8ea055886..a5c6da578 100644 --- a/src/api/Elastic.Documentation.Api.Lambda/Program.cs +++ b/src/api/Elastic.Documentation.Api.Lambda/Program.cs @@ -2,23 +2,26 @@ // Elasticsearch B.V licenses this file to you under the Apache 2.0 License. // See the LICENSE file in the project root for more information +using System.Text.Json; using System.Text.Json.Serialization; using Amazon.Lambda.APIGatewayEvents; using Amazon.Lambda.Serialization.SystemTextJson; +using Elastic.Documentation.Api.Core.AskAi; using Elastic.Documentation.Api.Infrastructure; var builder = WebApplication.CreateSlimBuilder(args); builder.Services.AddAWSLambdaHosting(LambdaEventSource.RestApi, new SourceGeneratorLambdaJsonSerializer()); -builder.Services.AddElasticDocsApiUsecases(Environment.GetEnvironmentVariable("APP_ENVIRONMENT")); +builder.Services.AddElasticDocsApiUsecases(Environment.GetEnvironmentVariable("ENVIRONMENT")); var app = builder.Build(); -var v1 = app.MapGroup("/v1"); +var v1 = app.MapGroup("/docs/_api/v1"); v1.MapElasticDocsApiEndpoints(); app.Run(); -[JsonSerializable(typeof(APIGatewayHttpApiV2ProxyRequest), GenerationMode = JsonSourceGenerationMode.Metadata)] -[JsonSerializable(typeof(APIGatewayHttpApiV2ProxyResponse), GenerationMode = JsonSourceGenerationMode.Default)] +[JsonSerializable(typeof(APIGatewayProxyRequest))] +[JsonSerializable(typeof(APIGatewayProxyResponse))] +[JsonSerializable(typeof(AskAiRequest))] internal sealed partial class LambdaJsonSerializerContext : JsonSerializerContext;