Skip to content

Commit 37f32ef

Browse files
committed
Rewrite ParameterProvider and its usage to make sure it's invoked at the INVOKE phase
the ssm parameter lambda exention does not work at INIT time
1 parent f4c5298 commit 37f32ef

File tree

6 files changed

+79
-66
lines changed

6 files changed

+79
-66
lines changed

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,20 @@
77
using System.Text.Json.Serialization;
88
using Elastic.Documentation.Api.Core.AskAi;
99
using Elastic.Documentation.Api.Infrastructure.Gcp;
10-
using Microsoft.Extensions.Options;
1110

1211
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
1312

14-
public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, IOptionsSnapshot<LlmGatewayOptions> options) : IAskAiGateway<Stream>
13+
public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
1514
{
1615
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
1716
{
1817
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
1918
var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
20-
var request = new HttpRequestMessage(HttpMethod.Post, options.Value.FunctionUrl)
19+
var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl)
2120
{
2221
Content = new StringContent(requestBody, Encoding.UTF8, "application/json")
2322
};
24-
var authToken = await tokenProvider.GenerateIdTokenAsync(ctx);
23+
var authToken = await tokenProvider.GenerateIdTokenAsync(options.ServiceAccount, options.TargetAudience, ctx);
2524
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", authToken);
2625
request.Headers.Add("User-Agent", "elastic-docs-proxy/1.0");
2726
request.Headers.Add("Accept", "text/event-stream");
@@ -44,7 +43,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
4443
PlatformContext: new PlatformContext("support_portal", "support_assistant", []),
4544
Input:
4645
[
47-
new ChatInput("system", AskAiRequest.SystemPrompt),
46+
new ChatInput("user", AskAiRequest.SystemPrompt),
4847
new ChatInput("user", request.Message)
4948
],
5049
ThreadId: request.ThreadId ?? "elastic-docs-" + Guid.NewGuid()
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to Elasticsearch B.V under one or more agreements.
2+
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
// See the LICENSE file in the project root for more information
4+
5+
using Elastic.Documentation.Api.Infrastructure.Aws;
6+
7+
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
8+
9+
public class LlmGatewayOptions
10+
{
11+
public LlmGatewayOptions(IParameterProvider parameterProvider)
12+
{
13+
ServiceAccount = parameterProvider.GetParam("llm-gateway-service-account").GetAwaiter().GetResult();
14+
FunctionUrl = parameterProvider.GetParam("llm-gateway-function-url").GetAwaiter().GetResult();
15+
var uri = new Uri(FunctionUrl);
16+
TargetAudience = $"{uri.Scheme}://{uri.Host}";
17+
}
18+
19+
public string ServiceAccount { get; }
20+
public string FunctionUrl { get; }
21+
public string TargetAudience { get; }
22+
}

src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace Elastic.Documentation.Api.Infrastructure.Aws;
1111

12-
public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
12+
public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, AppEnvironment appEnvironment, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
1313
{
1414
public const string HttpClientName = "AwsParametersAndSecretsLambdaExtensionClient";
1515
private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName);
@@ -18,8 +18,10 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
1818
{
1919
try
2020
{
21-
logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", name);
22-
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(name)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
21+
var prefix = $"/elastic-docs-v3/{appEnvironment.Current.ToStringFast(true)}/";
22+
var prefixedName = prefix + name.TrimStart('/');
23+
logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", prefixedName);
24+
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(prefixedName)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
2325
return response?.Parameter?.Value ?? throw new InvalidOperationException($"Parameter value for '{name}' is null.");
2426
}
2527
catch (HttpRequestException httpEx)
@@ -42,23 +44,23 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
4244

4345
internal sealed class ParameterResponse
4446
{
45-
public Parameter? Parameter { get; set; }
47+
public required Parameter Parameter { get; set; }
4648
}
4749

4850
internal sealed class Parameter
4951
{
50-
public string? Arn { get; set; }
51-
public string? Name { get; set; }
52-
public string? Type { get; set; }
53-
public string? Value { get; set; }
54-
public string? Version { get; set; }
52+
[JsonPropertyName("ARN")]
53+
public required string Arn { get; set; }
54+
public required string Name { get; set; }
55+
public required string Type { get; set; }
56+
public required string Value { get; set; }
57+
public required int Version { get; set; }
5558
public string? Selector { get; set; }
56-
public string? LastModifiedDate { get; set; }
57-
public string? LastModifiedUser { get; set; }
58-
public string? DataType { get; set; }
59+
public DateTime LastModifiedDate { get; set; }
60+
public required string DataType { get; set; }
5961
}
6062

6163

6264
[JsonSerializable(typeof(ParameterResponse))]
63-
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.CamelCase)]
65+
[JsonSourceGenerationOptions(PropertyNamingPolicy = JsonKnownNamingPolicy.Unspecified)]
6466
internal sealed partial class AwsJsonContext : JsonSerializerContext;

src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
1010
{
1111
switch (name)
1212
{
13-
case "/elastic-docs-v3/dev/llm-gateway-service-account":
13+
case "llm-gateway-service-account":
1414
{
1515
const string envName = "LLM_GATEWAY_SERVICE_ACCOUNT_KEY_PATH";
1616
var serviceAccountKeyPath = Environment.GetEnvironmentVariable(envName);
@@ -21,7 +21,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
2121
var serviceAccountKey = await File.ReadAllTextAsync(serviceAccountKeyPath, ctx);
2222
return serviceAccountKey;
2323
}
24-
case "/elastic-docs-v3/dev/llm-gateway-function-url":
24+
case "llm-gateway-function-url":
2525
{
2626
const string envName = "LLM_GATEWAY_FUNCTION_URL";
2727
var value = Environment.GetEnvironmentVariable(envName);

src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,35 @@
66
using System.Text;
77
using System.Text.Json;
88
using System.Text.Json.Serialization;
9+
using Elastic.Documentation.Api.Infrastructure.Aws;
10+
using Microsoft.Extensions.Configuration;
911
using Microsoft.Extensions.Options;
1012

1113
namespace Elastic.Documentation.Api.Infrastructure.Gcp;
1214

1315
// This is a custom implementation to create an ID token for GCP.
1416
// Because Google.Api.Auth.OAuth2 is not compatible with AOT
15-
public class GcpIdTokenProvider(HttpClient httpClient, IOptionsSnapshot<LlmGatewayOptions> options)
17+
public class GcpIdTokenProvider(HttpClient httpClient)
1618
{
17-
public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = default)
19+
public async Task<string> GenerateIdTokenAsync(string serviceAccount, string targetAudience, Cancel cancellationToken = default)
1820
{
1921
// Read and parse service account key file using System.Text.Json source generation (AOT compatible)
20-
var serviceAccount = JsonSerializer.Deserialize(options.Value.ServiceAccount, GcpJsonContext.Default.ServiceAccountKey);
22+
var serviceAccountJson = JsonSerializer.Deserialize(serviceAccount, GcpJsonContext.Default.ServiceAccountKey);
2123

2224
// Create JWT header
23-
var header = new JwtHeader("RS256", "JWT", serviceAccount.PrivateKeyId);
25+
var header = new JwtHeader("RS256", "JWT", serviceAccountJson.PrivateKeyId);
2426
var headerJson = JsonSerializer.Serialize(header, JwtHeaderJsonContext.Default.JwtHeader);
2527
var headerBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson));
2628

2729
// Create JWT payload
2830
var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
2931
var payload = new JwtPayload(
30-
serviceAccount.ClientEmail,
31-
serviceAccount.ClientEmail,
32+
serviceAccountJson.ClientEmail,
33+
serviceAccountJson.ClientEmail,
3234
"https://oauth2.googleapis.com/token",
3335
now,
3436
now + 300, // 5 minutes
35-
options.Value.TargetAudience
37+
targetAudience
3638
);
3739

3840
var payloadJson = JsonSerializer.Serialize(payload, GcpJsonContext.Default.JwtPayload);
@@ -43,7 +45,7 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
4345
var messageBytes = Encoding.UTF8.GetBytes(message);
4446

4547
// Parse the private key (removing PEM headers/footers and decoding)
46-
var privateKeyPem = serviceAccount.PrivateKey
48+
var privateKeyPem = serviceAccountJson.PrivateKey
4749
.Replace("-----BEGIN PRIVATE KEY-----", "")
4850
.Replace("-----END PRIVATE KEY-----", "")
4951
.Replace("\n", "")
@@ -59,7 +61,7 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
5961
var jwt = $"{message}.{signatureBase64}";
6062

6163
// Exchange JWT for ID token
62-
return await ExchangeJwtForIdToken(jwt, options.Value.TargetAudience, cancellationToken);
64+
return await ExchangeJwtForIdToken(jwt, targetAudience, cancellationToken);
6365
}
6466

6567

src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,17 @@
1515
namespace Elastic.Documentation.Api.Infrastructure;
1616

1717
[EnumExtensions]
18-
public enum AppEnvironment
18+
public enum AppEnv
1919
{
2020
[Display(Name = "dev")] Dev,
2121
[Display(Name = "staging")] Staging,
2222
[Display(Name = "edge")] Edge,
2323
[Display(Name = "prod")] Prod
2424
}
2525

26-
public class LlmGatewayOptions
26+
public class AppEnvironment
2727
{
28-
public string ServiceAccount { get; set; } = string.Empty;
29-
public string FunctionUrl { get; set; } = string.Empty;
30-
public string TargetAudience { get; set; } = string.Empty;
28+
public AppEnv Current { get; init; }
3129
}
3230

3331
public static class ServicesExtension
@@ -41,42 +39,44 @@ public static class ServicesExtension
4139

4240
public static void AddElasticDocsApiUsecases(this IServiceCollection services, string? appEnvironment)
4341
{
44-
if (AppEnvironmentExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
42+
if (AppEnvExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
4543
{
4644
AddElasticDocsApiUsecases(services, parsedEnvironment);
4745
}
4846
else
4947
{
5048
var logger = GetLogger(services);
5149
logger?.LogWarning("Unable to parse environment {AppEnvironment} into AppEnvironment. Using default AppEnvironment.Dev", appEnvironment);
52-
AddElasticDocsApiUsecases(services, AppEnvironment.Dev);
50+
AddElasticDocsApiUsecases(services, AppEnv.Dev);
5351
}
5452
}
5553

5654

57-
private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnvironment appEnvironment)
55+
private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnv appEnv)
5856
{
5957
_ = services.ConfigureHttpJsonOptions(options =>
6058
{
6159
options.SerializerOptions.TypeInfoResolverChain.Insert(0, ApiJsonContext.Default);
6260
});
6361
_ = services.AddHttpClient();
64-
AddParameterProvider(services, appEnvironment);
65-
AddAskAiUsecase(services, appEnvironment);
62+
// Register AppEnvironment as a singleton for dependency injection
63+
_ = services.AddSingleton(new AppEnvironment { Current = appEnv });
64+
AddParameterProvider(services, appEnv);
65+
AddAskAiUsecase(services, appEnv);
6666
}
6767

68-
// https://docs.aws.amazon.com/systems-manager/latest/userguide/ps-integration-lambda-extensions.html
69-
private static void AddParameterProvider(IServiceCollection services, AppEnvironment appEnvironment)
68+
// https://docs.aws.amazon.com/systems -manager/latest/userguide/ps-integration-lambda-extensions.html
69+
private static void AddParameterProvider(IServiceCollection services, AppEnv appEnv)
7070
{
7171
var logger = GetLogger(services);
7272

73-
switch (appEnvironment)
73+
switch (appEnv)
7474
{
75-
case AppEnvironment.Prod:
76-
case AppEnvironment.Staging:
77-
case AppEnvironment.Edge:
75+
case AppEnv.Prod:
76+
case AppEnv.Staging:
77+
case AppEnv.Edge:
7878
{
79-
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnvironment);
79+
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnv);
8080
_ = services.AddHttpClient(LambdaExtensionParameterProvider.HttpClientName, client =>
8181
{
8282
client.BaseAddress = new Uri("http://localhost:2773");
@@ -85,39 +85,27 @@ private static void AddParameterProvider(IServiceCollection services, AppEnviron
8585
_ = services.AddSingleton<IParameterProvider, LambdaExtensionParameterProvider>();
8686
break;
8787
}
88-
case AppEnvironment.Dev:
88+
case AppEnv.Dev:
8989
{
90-
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnvironment);
90+
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnv);
9191
_ = services.AddSingleton<IParameterProvider, LocalParameterProvider>();
9292
break;
9393
}
9494
default:
9595
{
96-
throw new ArgumentOutOfRangeException(nameof(appEnvironment), appEnvironment,
96+
throw new ArgumentOutOfRangeException(nameof(appEnv), appEnv,
9797
"Unsupported environment for parameter provider.");
9898
}
9999
}
100100
}
101101

102-
private static void AddAskAiUsecase(IServiceCollection services, AppEnvironment appEnvironment)
102+
private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
103103
{
104104
var logger = GetLogger(services);
105-
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnvironment);
106-
107-
_ = services.Configure<LlmGatewayOptions>(options =>
108-
{
109-
var serviceProvider = services.BuildServiceProvider();
110-
var parameterProvider = serviceProvider.GetRequiredService<IParameterProvider>();
111-
var appEnvString = appEnvironment.ToStringFast(true);
112-
113-
options.ServiceAccount = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-service-account").GetAwaiter().GetResult();
114-
options.FunctionUrl = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-function-url").GetAwaiter().GetResult();
115-
116-
var functionUri = new Uri(options.FunctionUrl);
117-
options.TargetAudience = $"{functionUri.Scheme}://{functionUri.Host}";
118-
});
119-
_ = services.AddScoped<GcpIdTokenProvider>();
120-
_ = services.AddScoped<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
105+
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnv);
106+
_ = services.AddSingleton<GcpIdTokenProvider>();
107+
_ = services.AddSingleton<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
108+
_ = services.AddScoped<LlmGatewayOptions>();
121109
_ = services.AddScoped<AskAiUsecase>();
122110
}
123111
}

0 commit comments

Comments
 (0)