Skip to content

Commit 60c6508

Browse files
authored
Update BedrockChatClient with support for Instructions/RawRepresentationFactory (#3906)
* Update BedrockChatClient with support for Instructions/RawRepresentationFactory - ChatOptions.RawRepresentationFactory, to support "breaking glass" if someone wants to set Converse{Stream}Request properties that aren't available via messages/options directly. - ChatOptions.Instructions, as a non-messages way to set a system prompt.
1 parent e3a2f2e commit 60c6508

File tree

6 files changed

+119
-102
lines changed

6 files changed

+119
-102
lines changed

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
</Choose>
3838

3939
<ItemGroup>
40-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.5.0" />
40+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
4141
</ItemGroup>
4242

4343
<ItemGroup>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
</Choose>
4242

4343
<ItemGroup>
44-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.5.0" />
44+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
4545
</ItemGroup>
4646

4747
<ItemGroup>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.nuspec

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
<metadata>
44
<id>AWSSDK.Extensions.Bedrock.MEAI</id>
55
<title>AWSSDK - Bedrock integration with Microsoft.Extensions.AI.</title>
6-
<version>4.0.1.1</version>
6+
<version>4.0.2.0</version>
77
<authors>Amazon Web Services</authors>
88
<description>Implementations of Microsoft.Extensions.AI's abstractions for Bedrock.</description>
99
<language>en-US</language>
@@ -15,17 +15,17 @@
1515
<group targetFramework="net472">
1616
<dependency id="AWSSDK.Core" version="4.0.0.4" />
1717
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
18-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.5.0" />
18+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
1919
</group>
2020
<group targetFramework="netstandard2.0">
2121
<dependency id="AWSSDK.Core" version="4.0.0.4" />
2222
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
23-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.5.0" />
23+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
2424
</group>
2525
<group targetFramework="net8.0">
2626
<dependency id="AWSSDK.Core" version="4.0.0.4" />
2727
<dependency id="AWSSDK.BedrockRuntime" version="4.0.0.3" />
28-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.5.0" />
28+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.7.0" />
2929
</group>
3030
</dependencies>
3131
</metadata>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs

Lines changed: 101 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -71,15 +71,13 @@ public async Task<ChatResponse> GetResponseAsync(
7171
throw new ArgumentNullException(nameof(messages));
7272
}
7373

74-
ConverseRequest request = new()
75-
{
76-
ModelId = options?.ModelId ?? _modelId,
77-
Messages = CreateMessages(messages),
78-
System = CreateSystem(messages),
79-
ToolConfig = CreateToolConfig(options),
80-
InferenceConfig = CreateInferenceConfiguration(options),
81-
AdditionalModelRequestFields = CreateAdditionalModelRequestFields(options),
82-
};
74+
ConverseRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseRequest ?? new();
75+
request.ModelId ??= options?.ModelId ?? _modelId;
76+
request.Messages = CreateMessages(request.Messages, messages);
77+
request.System = CreateSystem(request.System, messages, options);
78+
request.ToolConfig = CreateToolConfig(request.ToolConfig, options);
79+
request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options);
80+
request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options);
8381

8482
var response = await _runtime.ConverseAsync(request, cancellationToken).ConfigureAwait(false);
8583

@@ -162,15 +160,13 @@ public async IAsyncEnumerable<ChatResponseUpdate> GetStreamingResponseAsync(
162160
throw new ArgumentNullException(nameof(messages));
163161
}
164162

165-
ConverseStreamRequest request = new()
166-
{
167-
ModelId = options?.ModelId ?? _modelId,
168-
Messages = CreateMessages(messages),
169-
System = CreateSystem(messages),
170-
ToolConfig = CreateToolConfig(options),
171-
InferenceConfig = CreateInferenceConfiguration(options),
172-
AdditionalModelRequestFields = CreateAdditionalModelRequestFields(options),
173-
};
163+
ConverseStreamRequest request = options?.RawRepresentationFactory?.Invoke(this) as ConverseStreamRequest ?? new();
164+
request.ModelId ??= options?.ModelId ?? _modelId;
165+
request.Messages = CreateMessages(request.Messages, messages);
166+
request.System = CreateSystem(request.System, messages, options);
167+
request.ToolConfig = CreateToolConfig(request.ToolConfig, options);
168+
request.InferenceConfig = CreateInferenceConfiguration(request.InferenceConfig, options);
169+
request.AdditionalModelRequestFields = CreateAdditionalModelRequestFields(request.AdditionalModelRequestFields, options);
174170

175171
var result = await _runtime.ConverseStreamAsync(request, cancellationToken).ConfigureAwait(false);
176172

@@ -356,11 +352,21 @@ private static ChatFinishReason GetChatFinishReason(StopReason stopReason) =>
356352
};
357353

358354
/// <summary>Creates a list of <see cref="SystemContentBlock"/> from the system messages in the provided <paramref name="messages"/>.</summary>
359-
private static List<SystemContentBlock> CreateSystem(IEnumerable<ChatMessage> messages) =>
360-
messages
355+
private static List<SystemContentBlock> CreateSystem(List<SystemContentBlock>? rawMessages, IEnumerable<ChatMessage> messages, ChatOptions? options)
356+
{
357+
List<SystemContentBlock> system = rawMessages ?? [];
358+
359+
if (options?.Instructions is { } instructions)
360+
{
361+
system.Add(new SystemContentBlock() { Text = instructions });
362+
}
363+
364+
system.AddRange(messages
361365
.Where(m => m.Role == ChatRole.System && m.Contents.Any(c => c is TextContent))
362-
.Select(m => new SystemContentBlock() { Text = string.Concat(m.Contents.OfType<TextContent>()) })
363-
.ToList();
366+
.Select(m => new SystemContentBlock() { Text = string.Concat(m.Contents.OfType<TextContent>()) }));
367+
368+
return system;
369+
}
364370

365371
/// <summary>Parses JSON tool input into a <see cref="Dictionary{String, Object}"/>.</summary>
366372
private static Dictionary<string, object?>? ParseToolInputs(string? jsonInput, out Exception? parseError)
@@ -382,9 +388,9 @@ private static List<SystemContentBlock> CreateSystem(IEnumerable<ChatMessage> me
382388
}
383389

384390
/// <summary>Creates a list of <see cref="Message"/> from the provided <paramref name="chatMessages"/>.</summary>
385-
private static List<Message> CreateMessages(IEnumerable<ChatMessage> chatMessages)
391+
private static List<Message> CreateMessages(List<Message>? rawMessages, IEnumerable<ChatMessage> chatMessages)
386392
{
387-
List<Message> messages = [];
393+
List<Message> messages = rawMessages ?? [];
388394

389395
foreach (ChatMessage chatMessage in chatMessages)
390396
{
@@ -681,100 +687,110 @@ private static Document ToDocument(JsonElement json)
681687
}
682688

683689
/// <summary>Creates an <see cref="ToolConfiguration"/> from the specified options.</summary>
684-
private static ToolConfiguration? CreateToolConfig(ChatOptions? options)
690+
private static ToolConfiguration? CreateToolConfig(ToolConfiguration? toolConfig, ChatOptions? options)
685691
{
686-
List<Tool>? tools = options?.Tools?.OfType<AIFunction>().Select(f =>
692+
if (options?.Tools is { Count: > 0 } tools)
687693
{
688-
Document inputs = default;
689-
List<Document> required = [];
690-
691-
if (f.JsonSchema.TryGetProperty("properties", out JsonElement properties))
694+
foreach (AITool tool in tools)
692695
{
693-
foreach (JsonProperty parameter in properties.EnumerateObject())
696+
if (tool is not AIFunction f)
694697
{
695-
inputs.Add(parameter.Name, ToDocument(parameter.Value));
698+
continue;
696699
}
697-
}
698700

699-
if (f.JsonSchema.TryGetProperty("required", out JsonElement requiredProperties))
700-
{
701-
foreach (JsonElement requiredProperty in requiredProperties.EnumerateArray())
701+
Document inputs = default;
702+
List<Document> required = [];
703+
704+
if (f.JsonSchema.TryGetProperty("properties", out JsonElement properties))
702705
{
703-
required.Add(requiredProperty.GetString());
706+
foreach (JsonProperty parameter in properties.EnumerateObject())
707+
{
708+
inputs.Add(parameter.Name, ToDocument(parameter.Value));
709+
}
704710
}
705-
}
706711

707-
var schemaDictionary = new Dictionary<string, Document>()
708-
{
709-
["type"] = new Document("object"),
710-
};
712+
if (f.JsonSchema.TryGetProperty("required", out JsonElement requiredProperties))
713+
{
714+
foreach (JsonElement requiredProperty in requiredProperties.EnumerateArray())
715+
{
716+
required.Add(requiredProperty.GetString());
717+
}
718+
}
711719

712-
if (inputs != default)
713-
{
714-
schemaDictionary["properties"] = inputs;
715-
}
720+
Dictionary<string, Document> schemaDictionary = new()
721+
{
722+
["type"] = new Document("object"),
723+
};
716724

717-
if (required.Count > 0)
718-
{
719-
schemaDictionary["required"] = new Document(required);
720-
}
725+
if (inputs != default)
726+
{
727+
schemaDictionary["properties"] = inputs;
728+
}
721729

722-
return new Tool()
723-
{
724-
ToolSpec = new ToolSpecification()
730+
if (required.Count > 0)
731+
{
732+
schemaDictionary["required"] = new Document(required);
733+
}
734+
735+
toolConfig ??= new();
736+
toolConfig.Tools ??= [];
737+
toolConfig.Tools.Add(new()
725738
{
726-
Name = f.Name,
727-
Description = !string.IsNullOrEmpty(f.Description) ? f.Description : f.Name,
728-
InputSchema = new()
739+
ToolSpec = new ToolSpecification()
729740
{
730-
Json = new(schemaDictionary)
741+
Name = f.Name,
742+
Description = !string.IsNullOrEmpty(f.Description) ? f.Description : f.Name,
743+
InputSchema = new()
744+
{
745+
Json = new(schemaDictionary)
746+
},
731747
},
732-
},
733-
};
734-
}).ToList();
748+
});
749+
}
750+
}
735751

736-
ToolChoice? choice = null;
737-
if (tools is { Count: > 0 })
752+
if (toolConfig?.Tools is { Count: > 0 } && toolConfig.ToolChoice is null)
738753
{
739754
switch (options!.ToolMode)
740755
{
741-
case AutoChatToolMode:
742-
case null:
743-
choice = new ToolChoice() { Auto = new() };
744-
break;
745-
746756
case RequiredChatToolMode r:
747-
choice = !string.IsNullOrWhiteSpace(r.RequiredFunctionName) ?
757+
toolConfig.ToolChoice = !string.IsNullOrWhiteSpace(r.RequiredFunctionName) ?
748758
new ToolChoice() { Tool = new() { Name = r.RequiredFunctionName } } :
749759
new ToolChoice() { Any = new() };
750760
break;
751761
}
752-
753-
return new()
754-
{
755-
ToolChoice = choice,
756-
Tools = tools,
757-
};
758762
}
759763

760-
return null;
764+
return toolConfig;
761765
}
762766

763767
/// <summary>Creates an <see cref="InferenceConfiguration"/> from the specified options.</summary>
764-
private static InferenceConfiguration CreateInferenceConfiguration(ChatOptions? options) =>
765-
new()
768+
private static InferenceConfiguration CreateInferenceConfiguration(InferenceConfiguration config, ChatOptions? options)
769+
{
770+
config ??= new();
771+
772+
config.MaxTokens ??= options?.MaxOutputTokens;
773+
config.Temperature ??= options?.Temperature;
774+
config.TopP ??= options?.TopP;
775+
776+
if (options?.StopSequences is { Count: > 0 } stopOptions)
766777
{
767-
MaxTokens = options?.MaxOutputTokens,
768-
StopSequences = options?.StopSequences?.ToList(),
769-
Temperature = options?.Temperature,
770-
TopP = options?.TopP,
771-
};
778+
if (config.StopSequences is null)
779+
{
780+
config.StopSequences = stopOptions.ToList();
781+
}
782+
else
783+
{
784+
config.StopSequences.AddRange(stopOptions);
785+
}
786+
}
787+
788+
return config;
789+
}
772790

773791
/// <summary>Creates a <see cref="Document"/> from the specified options to use as the additional model request options.</summary>
774-
private static Document CreateAdditionalModelRequestFields(ChatOptions? options)
792+
private static Document CreateAdditionalModelRequestFields(Document d, ChatOptions? options)
775793
{
776-
Document d = default;
777-
778794
if (options is not null)
779795
{
780796
if (options.TopK is int topK)

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockEmbeddingGenerator.cs

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
* permissions and limitations under the License.
1414
*/
1515

16+
using Amazon.BedrockRuntime.Model;
1617
using Microsoft.Extensions.AI;
1718
using System;
1819
using System.Collections.Generic;
@@ -89,17 +90,17 @@ public async Task<GeneratedEmbeddings<Embedding<float>>> GenerateAsync(
8990

9091
foreach (string value in values)
9192
{
92-
var response = await _runtime.InvokeModelAsync(new()
93+
InvokeModelRequest request = options?.RawRepresentationFactory?.Invoke(this) as InvokeModelRequest ?? new();
94+
request.ModelId ??= options?.ModelId ?? _modelId;
95+
request.Accept ??= "application/json";
96+
request.ContentType ??= "application/json";
97+
request.Body ??= new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(new()
9398
{
94-
ModelId = options?.ModelId ?? _modelId,
95-
Accept = "application/json",
96-
ContentType = "application/json",
97-
Body = new MemoryStream(JsonSerializer.SerializeToUtf8Bytes(new EmbeddingRequest()
98-
{
99-
InputText = value,
100-
Dimensions = options?.Dimensions ?? _dimensions,
101-
}, BedrockJsonContext.Default.EmbeddingRequest)),
102-
}, cancellationToken).ConfigureAwait(false);
99+
InputText = value,
100+
Dimensions = options?.Dimensions ?? _dimensions,
101+
}, BedrockJsonContext.Default.EmbeddingRequest));
102+
103+
var response = await _runtime.InvokeModelAsync(request, cancellationToken).ConfigureAwait(false);
103104

104105
var er = JsonSerializer.Deserialize(response.Body, BedrockJsonContext.Default.EmbeddingResponse);
105106
if (er?.Embedding is not null)

extensions/test/BedrockMEAITests/BedrockMEAITests.NetFramework.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
</PropertyGroup>
1919

2020
<ItemGroup>
21-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.5.0" />
21+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.7.0" />
2222
<PackageReference Include="xunit" Version="2.9.2" />
2323
<PackageReference Include="xunit.runner.visualstudio" Version="2.8.2" />
2424
</ItemGroup>

0 commit comments

Comments
 (0)