Skip to content

Commit d9d2923

Browse files
committed
Change useage of ParameterProvider
Because params can only be fetched at invoke time
1 parent 6a6fa57 commit d9d2923

File tree

6 files changed

+69
-55
lines changed

6 files changed

+69
-55
lines changed

src/api/Elastic.Documentation.Api.Infrastructure/Adapters/AskAi/LlmGatewayAskAiGateway.cs

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,20 @@
77
using System.Text.Json.Serialization;
88
using Elastic.Documentation.Api.Core.AskAi;
99
using Elastic.Documentation.Api.Infrastructure.Gcp;
10-
using Microsoft.Extensions.Options;
1110

1211
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
1312

14-
public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, IOptionsSnapshot<LlmGatewayOptions> options) : IAskAiGateway<Stream>
13+
public class LlmGatewayAskAiGateway(HttpClient httpClient, GcpIdTokenProvider tokenProvider, LlmGatewayOptions options) : IAskAiGateway<Stream>
1514
{
1615
public async Task<Stream> AskAi(AskAiRequest askAiRequest, Cancel ctx = default)
1716
{
1817
var llmGatewayRequest = LlmGatewayRequest.CreateFromRequest(askAiRequest);
1918
var requestBody = JsonSerializer.Serialize(llmGatewayRequest, LlmGatewayContext.Default.LlmGatewayRequest);
20-
var request = new HttpRequestMessage(HttpMethod.Post, options.Value.FunctionUrl)
19+
var request = new HttpRequestMessage(HttpMethod.Post, options.FunctionUrl)
2120
{
2221
Content = new StringContent(requestBody, Encoding.UTF8, "application/json")
2322
};
24-
var authToken = await tokenProvider.GenerateIdTokenAsync(ctx);
23+
var authToken = await tokenProvider.GenerateIdTokenAsync(options.ServiceAccount, options.TargetAudience, ctx);
2524
request.Headers.Authorization = new System.Net.Http.Headers.AuthenticationHeaderValue("Bearer", authToken);
2625
request.Headers.Add("User-Agent", "elastic-docs-proxy/1.0");
2726
request.Headers.Add("Accept", "text/event-stream");
@@ -44,7 +43,7 @@ public static LlmGatewayRequest CreateFromRequest(AskAiRequest request) =>
4443
PlatformContext: new PlatformContext("support_portal", "support_assistant", []),
4544
Input:
4645
[
47-
new ChatInput("system", AskAiRequest.SystemPrompt),
46+
new ChatInput("user", AskAiRequest.SystemPrompt),
4847
new ChatInput("user", request.Message)
4948
],
5049
ThreadId: request.ThreadId ?? "elastic-docs-" + Guid.NewGuid()
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
// Licensed to Elasticsearch B.V under one or more agreements.
2+
// Elasticsearch B.V licenses this file to you under the Apache 2.0 License.
3+
// See the LICENSE file in the project root for more information
4+
5+
using Elastic.Documentation.Api.Infrastructure.Aws;
6+
7+
namespace Elastic.Documentation.Api.Infrastructure.Adapters.AskAi;
8+
9+
public class LlmGatewayOptions
10+
{
11+
12+
public LlmGatewayOptions(IParameterProvider parameterProvider)
13+
{
14+
ServiceAccount = parameterProvider.GetParam("llm-gateway-service-account", true).GetAwaiter().GetResult();
15+
FunctionUrl = parameterProvider.GetParam("llm-gateway-function-url", true).GetAwaiter().GetResult();
16+
var uri = new Uri(FunctionUrl);
17+
TargetAudience = $"{uri.Scheme}://{uri.Host}";
18+
}
19+
20+
public string ServiceAccount { get; }
21+
public string FunctionUrl { get; }
22+
public string TargetAudience { get; }
23+
}

src/api/Elastic.Documentation.Api.Infrastructure/Aws/LambdaExtensionParameterProvider.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
namespace Elastic.Documentation.Api.Infrastructure.Aws;
1111

12-
public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
12+
public class LambdaExtensionParameterProvider(IHttpClientFactory httpClientFactory, AppEnvironment appEnvironment, ILogger<LambdaExtensionParameterProvider> logger) : IParameterProvider
1313
{
1414
public const string HttpClientName = "AwsParametersAndSecretsLambdaExtensionClient";
1515
private readonly HttpClient _httpClient = httpClientFactory.CreateClient(HttpClientName);
@@ -19,7 +19,9 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
1919
try
2020
{
2121
logger.LogInformation("Retrieving parameter '{Name}' from Lambda Extension (SSM Parameter Store).", name);
22-
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(name)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
22+
var prefix = $"$/elastic-docs-v3/{appEnvironment.Current.ToStringFast(true)}/";
23+
var prefixedName = prefix + name.TrimStart('/');
24+
var response = await _httpClient.GetFromJsonAsync<ParameterResponse>($"/systemsmanager/parameters/get?name={Uri.EscapeDataString(prefixedName)}&withDecryption={withDecryption.ToString().ToLowerInvariant()}", AwsJsonContext.Default.ParameterResponse, ctx);
2325
return response?.Parameter?.Value ?? throw new InvalidOperationException($"Parameter value for '{name}' is null.");
2426
}
2527
catch (HttpRequestException httpEx)

src/api/Elastic.Documentation.Api.Infrastructure/Aws/LocalParameterProvider.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
1010
{
1111
switch (name)
1212
{
13-
case "/elastic-docs-v3/dev/llm-gateway-service-account":
13+
case "llm-gateway-service-account":
1414
{
1515
const string envName = "LLM_GATEWAY_SERVICE_ACCOUNT_KEY_PATH";
1616
var serviceAccountKeyPath = Environment.GetEnvironmentVariable(envName);
@@ -21,7 +21,7 @@ public async Task<string> GetParam(string name, bool withDecryption = true, Canc
2121
var serviceAccountKey = await File.ReadAllTextAsync(serviceAccountKeyPath, ctx);
2222
return serviceAccountKey;
2323
}
24-
case "/elastic-docs-v3/dev/llm-gateway-function-url":
24+
case "llm-gateway-function-url":
2525
{
2626
const string envName = "LLM_GATEWAY_FUNCTION_URL";
2727
var value = Environment.GetEnvironmentVariable(envName);

src/api/Elastic.Documentation.Api.Infrastructure/Gcp/GcpIdTokenProvider.cs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,33 +6,35 @@
66
using System.Text;
77
using System.Text.Json;
88
using System.Text.Json.Serialization;
9+
using Elastic.Documentation.Api.Infrastructure.Aws;
10+
using Microsoft.Extensions.Configuration;
911
using Microsoft.Extensions.Options;
1012

1113
namespace Elastic.Documentation.Api.Infrastructure.Gcp;
1214

1315
// This is a custom implementation to create an ID token for GCP.
1416
// Because Google.Api.Auth.OAuth2 is not compatible with AOT
15-
public class GcpIdTokenProvider(HttpClient httpClient, IOptionsSnapshot<LlmGatewayOptions> options)
17+
public class GcpIdTokenProvider(HttpClient httpClient)
1618
{
17-
public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = default)
19+
public async Task<string> GenerateIdTokenAsync(string serviceAccount, string targetAudience, Cancel cancellationToken = default)
1820
{
1921
// Read and parse service account key file using System.Text.Json source generation (AOT compatible)
20-
var serviceAccount = JsonSerializer.Deserialize(options.Value.ServiceAccount, GcpJsonContext.Default.ServiceAccountKey);
22+
var serviceAccountJson = JsonSerializer.Deserialize(serviceAccount, GcpJsonContext.Default.ServiceAccountKey);
2123

2224
// Create JWT header
23-
var header = new JwtHeader("RS256", "JWT", serviceAccount.PrivateKeyId);
25+
var header = new JwtHeader("RS256", "JWT", serviceAccountJson.PrivateKeyId);
2426
var headerJson = JsonSerializer.Serialize(header, JwtHeaderJsonContext.Default.JwtHeader);
2527
var headerBase64 = Base64UrlEncode(Encoding.UTF8.GetBytes(headerJson));
2628

2729
// Create JWT payload
2830
var now = DateTimeOffset.UtcNow.ToUnixTimeSeconds();
2931
var payload = new JwtPayload(
30-
serviceAccount.ClientEmail,
31-
serviceAccount.ClientEmail,
32+
serviceAccountJson.ClientEmail,
33+
serviceAccountJson.ClientEmail,
3234
"https://oauth2.googleapis.com/token",
3335
now,
3436
now + 300, // 5 minutes
35-
options.Value.TargetAudience
37+
targetAudience
3638
);
3739

3840
var payloadJson = JsonSerializer.Serialize(payload, GcpJsonContext.Default.JwtPayload);
@@ -43,7 +45,7 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
4345
var messageBytes = Encoding.UTF8.GetBytes(message);
4446

4547
// Parse the private key (removing PEM headers/footers and decoding)
46-
var privateKeyPem = serviceAccount.PrivateKey
48+
var privateKeyPem = serviceAccountJson.PrivateKey
4749
.Replace("-----BEGIN PRIVATE KEY-----", "")
4850
.Replace("-----END PRIVATE KEY-----", "")
4951
.Replace("\n", "")
@@ -59,7 +61,7 @@ public async Task<string> GenerateIdTokenAsync(Cancel cancellationToken = defaul
5961
var jwt = $"{message}.{signatureBase64}";
6062

6163
// Exchange JWT for ID token
62-
return await ExchangeJwtForIdToken(jwt, options.Value.TargetAudience, cancellationToken);
64+
return await ExchangeJwtForIdToken(jwt, targetAudience, cancellationToken);
6365
}
6466

6567

src/api/Elastic.Documentation.Api.Infrastructure/ServicesExtension.cs

Lines changed: 25 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,17 @@
1515
namespace Elastic.Documentation.Api.Infrastructure;
1616

1717
[EnumExtensions]
18-
public enum AppEnvironment
18+
public enum AppEnv
1919
{
2020
[Display(Name = "dev")] Dev,
2121
[Display(Name = "staging")] Staging,
2222
[Display(Name = "edge")] Edge,
2323
[Display(Name = "prod")] Prod
2424
}
2525

26-
public class LlmGatewayOptions
26+
public class AppEnvironment
2727
{
28-
public string ServiceAccount { get; set; } = string.Empty;
29-
public string FunctionUrl { get; set; } = string.Empty;
30-
public string TargetAudience { get; set; } = string.Empty;
28+
public AppEnv Current { get; init; }
3129
}
3230

3331
public static class ServicesExtension
@@ -41,42 +39,44 @@ public static class ServicesExtension
4139

4240
public static void AddElasticDocsApiUsecases(this IServiceCollection services, string? appEnvironment)
4341
{
44-
if (AppEnvironmentExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
42+
if (AppEnvExtensions.TryParse(appEnvironment, out var parsedEnvironment, true))
4543
{
4644
AddElasticDocsApiUsecases(services, parsedEnvironment);
4745
}
4846
else
4947
{
5048
var logger = GetLogger(services);
5149
logger?.LogWarning("Unable to parse environment {AppEnvironment} into AppEnvironment. Using default AppEnvironment.Dev", appEnvironment);
52-
AddElasticDocsApiUsecases(services, AppEnvironment.Dev);
50+
AddElasticDocsApiUsecases(services, AppEnv.Dev);
5351
}
5452
}
5553

5654

57-
private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnvironment appEnvironment)
55+
private static void AddElasticDocsApiUsecases(this IServiceCollection services, AppEnv appEnv)
5856
{
5957
_ = services.ConfigureHttpJsonOptions(options =>
6058
{
6159
options.SerializerOptions.TypeInfoResolverChain.Insert(0, ApiJsonContext.Default);
6260
});
6361
_ = services.AddHttpClient();
64-
AddParameterProvider(services, appEnvironment);
65-
AddAskAiUsecase(services, appEnvironment);
62+
// Register AppEnvironment as a singleton for dependency injection
63+
_ = services.AddSingleton(new AppEnvironment { Current = appEnv });
64+
AddParameterProvider(services, appEnv);
65+
AddAskAiUsecase(services, appEnv);
6666
}
6767

68-
// https://docs.aws.amazon.com/systems-manager/latest/userguide/ps-integration-lambda-extensions.html
69-
private static void AddParameterProvider(IServiceCollection services, AppEnvironment appEnvironment)
68+
// https://docs.aws.amazon.com/systems -manager/latest/userguide/ps-integration-lambda-extensions.html
69+
private static void AddParameterProvider(IServiceCollection services, AppEnv appEnv)
7070
{
7171
var logger = GetLogger(services);
7272

73-
switch (appEnvironment)
73+
switch (appEnv)
7474
{
75-
case AppEnvironment.Prod:
76-
case AppEnvironment.Staging:
77-
case AppEnvironment.Edge:
75+
case AppEnv.Prod:
76+
case AppEnv.Staging:
77+
case AppEnv.Edge:
7878
{
79-
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnvironment);
79+
logger?.LogInformation("Configuring LambdaExtensionParameterProvider for environment {AppEnvironment}", appEnv);
8080
_ = services.AddHttpClient(LambdaExtensionParameterProvider.HttpClientName, client =>
8181
{
8282
client.BaseAddress = new Uri("http://localhost:2773");
@@ -85,39 +85,27 @@ private static void AddParameterProvider(IServiceCollection services, AppEnviron
8585
_ = services.AddSingleton<IParameterProvider, LambdaExtensionParameterProvider>();
8686
break;
8787
}
88-
case AppEnvironment.Dev:
88+
case AppEnv.Dev:
8989
{
90-
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnvironment);
90+
logger?.LogInformation("Configuring LocalParameterProvider for environment {AppEnvironment}", appEnv);
9191
_ = services.AddSingleton<IParameterProvider, LocalParameterProvider>();
9292
break;
9393
}
9494
default:
9595
{
96-
throw new ArgumentOutOfRangeException(nameof(appEnvironment), appEnvironment,
96+
throw new ArgumentOutOfRangeException(nameof(appEnv), appEnv,
9797
"Unsupported environment for parameter provider.");
9898
}
9999
}
100100
}
101101

102-
private static void AddAskAiUsecase(IServiceCollection services, AppEnvironment appEnvironment)
102+
private static void AddAskAiUsecase(IServiceCollection services, AppEnv appEnv)
103103
{
104104
var logger = GetLogger(services);
105-
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnvironment);
106-
107-
_ = services.Configure<LlmGatewayOptions>(options =>
108-
{
109-
var serviceProvider = services.BuildServiceProvider();
110-
var parameterProvider = serviceProvider.GetRequiredService<IParameterProvider>();
111-
var appEnvString = appEnvironment.ToStringFast(true);
112-
113-
options.ServiceAccount = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-service-account").GetAwaiter().GetResult();
114-
options.FunctionUrl = parameterProvider.GetParam($"/elastic-docs-v3/{appEnvString}/llm-gateway-function-url").GetAwaiter().GetResult();
115-
116-
var functionUri = new Uri(options.FunctionUrl);
117-
options.TargetAudience = $"{functionUri.Scheme}://{functionUri.Host}";
118-
});
119-
_ = services.AddScoped<GcpIdTokenProvider>();
120-
_ = services.AddScoped<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
105+
logger?.LogInformation("Configuring AskAi use case for environment {AppEnvironment}", appEnv);
106+
_ = services.AddSingleton<GcpIdTokenProvider>();
107+
_ = services.AddSingleton<IAskAiGateway<Stream>, LlmGatewayAskAiGateway>();
108+
_ = services.AddScoped<LlmGatewayOptions>();
121109
_ = services.AddScoped<AskAiUsecase>();
122110
}
123111
}

0 commit comments

Comments
 (0)