Skip to content

Commit a82cd0d

Browse files
committed
Update IChatClient with support from latest bedrock runtime / M.E.AI
- Adds support for multi-modal tool returns. - Adds support for citations with URIs. - Adds a ton of tests verifying IChatClient behavior around the underlying IAmazonBedrockRuntime.
1 parent f88844e commit a82cd0d

File tree

7 files changed

+3511
-627
lines changed

7 files changed

+3511
-627
lines changed

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetFramework.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@
3737
</Choose>
3838

3939
<ItemGroup>
40-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.9.1" />
40+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="10.1.1" />
4141
</ItemGroup>
4242

4343
<ItemGroup>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.NetStandard.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
</Choose>
4242

4343
<ItemGroup>
44-
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="9.9.1" />
44+
<PackageReference Include="Microsoft.Extensions.AI.Abstractions" Version="10.1.1" />
4545
</ItemGroup>
4646

4747
<ItemGroup>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/AWSSDK.Extensions.Bedrock.MEAI.nuspec

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,17 +15,17 @@
1515
<group targetFramework="net472">
1616
<dependency id="AWSSDK.Core" version="4.0.3.6" />
1717
<dependency id="AWSSDK.BedrockRuntime" version="4.0.14.3" />
18-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.1" />
18+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="10.1.1" />
1919
</group>
2020
<group targetFramework="netstandard2.0">
2121
<dependency id="AWSSDK.Core" version="4.0.3.6" />
2222
<dependency id="AWSSDK.BedrockRuntime" version="4.0.14.3" />
23-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.1" />
23+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="10.1.1" />
2424
</group>
2525
<group targetFramework="net8.0">
2626
<dependency id="AWSSDK.Core" version="4.0.3.6" />
2727
<dependency id="AWSSDK.BedrockRuntime" version="4.0.14.3" />
28-
<dependency id="Microsoft.Extensions.AI.Abstractions" version="9.9.1" />
28+
<dependency id="Microsoft.Extensions.AI.Abstractions" version="10.1.1" />
2929
</group>
3030
</dependencies>
3131
</metadata>

extensions/src/AWSSDK.Extensions.Bedrock.MEAI/BedrockChatClient.cs

Lines changed: 45 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,9 @@ public async Task<ChatResponse> GetResponseAsync(
163163
TextContent tc = new(citations.Content[i]?.Text) { RawRepresentation = citations.Content[i] };
164164
tc.Annotations = [new CitationAnnotation()
165165
{
166+
Snippet = citations.Citations[i].SourceContent?.Select(c => c.Text).FirstOrDefault() ?? citations.Citations[i].Source,
166167
Title = citations.Citations[i].Title,
167-
Snippet = citations.Citations[i].SourceContent?.Select(c => c.Text).FirstOrDefault(),
168+
Url = Uri.TryCreate(citations.Citations[i].Location?.Web?.Url, UriKind.Absolute, out Uri? uri) ? uri : null,
168169
}];
169170
result.Contents.Add(tc);
170171
}
@@ -424,15 +425,11 @@ private static UsageDetails CreateUsageDetails(TokenUsage usage)
424425
UsageDetails ud = new()
425426
{
426427
InputTokenCount = usage.InputTokens,
428+
CachedInputTokenCount = usage.CacheReadInputTokens,
427429
OutputTokenCount = usage.OutputTokens,
428430
TotalTokenCount = usage.TotalTokens,
429431
};
430432

431-
if (usage.CacheReadInputTokens is int cacheReadTokens)
432-
{
433-
(ud.AdditionalCounts ??= []).Add(nameof(usage.CacheReadInputTokens), cacheReadTokens);
434-
}
435-
436433
if (usage.CacheWriteInputTokens is int cacheWriteTokens)
437434
{
438435
(ud.AdditionalCounts ??= []).Add(nameof(usage.CacheWriteInputTokens), cacheWriteTokens);
@@ -467,8 +464,7 @@ private static List<SystemContentBlock> CreateSystem(List<SystemContentBlock>? r
467464
});
468465
}
469466

470-
foreach (var message in messages
471-
.Where(m => m.Role == ChatRole.System && m.Contents.Any(c => c is TextContent)))
467+
foreach (var message in messages.Where(m => m.Role == ChatRole.System && m.Contents.Any(c => c is TextContent)))
472468
{
473469
system.Add(new SystemContentBlock()
474470
{
@@ -569,6 +565,10 @@ private static List<ContentBlock> CreateContents(ChatMessage message)
569565
{
570566
switch (content)
571567
{
568+
case AIContent when content.RawRepresentation is ContentBlock cb:
569+
contents.Add(cb);
570+
break;
571+
572572
case TextContent tc:
573573
if (message.Role == ChatRole.Assistant)
574574
{
@@ -651,32 +651,54 @@ private static List<ContentBlock> CreateContents(ChatMessage message)
651651
break;
652652

653653
case FunctionResultContent frc:
654-
Document result = frc.Result switch
655-
{
656-
int i => i,
657-
long l => l,
658-
float f => f,
659-
double d => d,
660-
string s => s,
661-
bool b => b,
662-
JsonElement json => ToDocument(json),
663-
{ } other => ToDocument(JsonSerializer.SerializeToElement(other, BedrockJsonContext.DefaultOptions.GetTypeInfo(other.GetType()))),
664-
_ => default,
665-
};
666-
667654
contents.Add(new()
668655
{
669656
ToolResult = new()
670657
{
671658
ToolUseId = frc.CallId,
672-
Content = [new() { Json = new Document(new Dictionary<string, Document>() { ["result"] = result }) }],
659+
Content = ToToolResultContentBlocks(frc.Result),
673660
},
674661
});
675662
break;
676663
}
677664

665+
static List<ToolResultContentBlock> ToToolResultContentBlocks(object? result) =>
666+
result switch
667+
{
668+
AIContent aic => [ToolResultContentBlockFromAIContent(aic)],
669+
IEnumerable<AIContent> aics => [.. aics.Select(ToolResultContentBlockFromAIContent)],
670+
string s => [new () { Text = s }],
671+
_ => [new()
672+
{
673+
Json = new Document(new Dictionary<string, Document>()
674+
{
675+
["result"] = result switch
676+
{
677+
int i => i,
678+
long l => l,
679+
float f => f,
680+
double d => d,
681+
bool b => b,
682+
JsonElement json => ToDocument(json),
683+
{ } other => ToDocument(JsonSerializer.SerializeToElement(other, BedrockJsonContext.DefaultOptions.GetTypeInfo(other.GetType()))),
684+
_ => default,
685+
}
686+
})
687+
}],
688+
};
689+
690+
static ToolResultContentBlock ToolResultContentBlockFromAIContent(AIContent aic) =>
691+
aic switch
692+
{
693+
TextContent tc => new() { Text = tc.Text },
694+
TextReasoningContent trc => new() { Text = trc.Text },
695+
DataContent dc when GetImageFormat(dc.MediaType) is { } imageFormat => new() { Image = new() { Source = new() { Bytes = new(dc.Data.ToArray()) }, Format = imageFormat } },
696+
DataContent dc when GetVideoFormat(dc.MediaType) is { } videoFormat => new() { Video = new() { Source = new() { Bytes = new(dc.Data.ToArray()) }, Format = videoFormat } },
697+
DataContent dc when GetDocumentFormat(dc.MediaType) is { } docFormat => new() { Document = new() { Source = new() { Bytes = new(dc.Data.ToArray()) }, Format = docFormat, Name = dc.Name ?? "file" } },
698+
_ => ToToolResultContentBlocks(JsonSerializer.SerializeToElement(aic, BedrockJsonContext.DefaultOptions.GetTypeInfo(typeof(object)))).First(),
699+
};
678700

679-
if (content.AdditionalProperties?.TryGetValue(nameof(ContentBlock.CachePoint), out var maybeCachePoint) == true)
701+
if (content.AdditionalProperties?.TryGetValue(nameof(ContentBlock.CachePoint), out var maybeCachePoint) is true)
680702
{
681703
if (maybeCachePoint is CachePointBlock cachePointBlock)
682704
{

0 commit comments

Comments
 (0)