diff --git a/src/api/providers/__tests__/bedrock-custom-arn.test.ts b/src/api/providers/__tests__/bedrock-custom-arn.test.ts index f7dc2870fa4..b5522180f9d 100644 --- a/src/api/providers/__tests__/bedrock-custom-arn.test.ts +++ b/src/api/providers/__tests__/bedrock-custom-arn.test.ts @@ -3,10 +3,20 @@ import { ApiHandlerOptions } from "../../../shared/api" // Mock the AWS SDK jest.mock("@aws-sdk/client-bedrock-runtime", () => { + const mockResponse = { + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, + } + const mockSend = jest.fn().mockImplementation(() => { - return Promise.resolve({ - output: new TextEncoder().encode(JSON.stringify({ content: "Test response" })), - }) + return Promise.resolve(mockResponse) }) return { diff --git a/src/api/providers/__tests__/bedrock.test.ts b/src/api/providers/__tests__/bedrock.test.ts index 74249ac820b..facba103fd8 100644 --- a/src/api/providers/__tests__/bedrock.test.ts +++ b/src/api/providers/__tests__/bedrock.test.ts @@ -399,14 +399,20 @@ describe("AwsBedrockHandler", () => { }) }) + //response.output.message.content[0].text + describe("completePrompt", () => { it("should complete prompt successfully", async () => { const mockResponse = { - output: new TextEncoder().encode( - JSON.stringify({ - content: "Test response", - }), - ), + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, } const mockSend = jest.fn().mockResolvedValue(mockResponse) @@ -450,7 +456,9 @@ describe("AwsBedrockHandler", () => { it("should handle invalid response format", async () => { const mockResponse = { - output: new TextEncoder().encode("invalid json"), + output: { + message: {}, + }, } const mockSend = jest.fn().mockResolvedValue(mockResponse) @@ -464,9 +472,16 @@ describe("AwsBedrockHandler", () => { it("should handle empty response", async () => { const mockResponse = { - output: new TextEncoder().encode(JSON.stringify({})), + output: { + message: { + content: [ + { + text: "", + }, + ], + }, + }, } - const mockSend = jest.fn().mockResolvedValue(mockResponse) handler["client"] = { send: mockSend, @@ -486,11 +501,15 @@ describe("AwsBedrockHandler", () => { }) const mockResponse = { - output: new TextEncoder().encode( - JSON.stringify({ - content: "Test response", - }), - ), + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, } const mockSend = jest.fn().mockResolvedValue(mockResponse) @@ -519,11 +538,15 @@ describe("AwsBedrockHandler", () => { }) const mockResponse = { - output: new TextEncoder().encode( - JSON.stringify({ - content: "Test response", - }), - ), + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, } const mockSend = jest.fn().mockResolvedValue(mockResponse) @@ -552,13 +575,16 @@ describe("AwsBedrockHandler", () => { }) const mockResponse = { - output: new TextEncoder().encode( - JSON.stringify({ - content: "Test response", - }), - ), + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, } - const mockSend = jest.fn().mockResolvedValue(mockResponse) handler["client"] = { send: mockSend, @@ -585,11 +611,15 @@ describe("AwsBedrockHandler", () => { }) const mockResponse = { - output: new TextEncoder().encode( - JSON.stringify({ - content: "Test response", - }), - ), + output: { + message: { + content: [ + { + text: "Test response", + }, + ], + }, + }, } const mockSend = jest.fn().mockResolvedValue(mockResponse) diff --git a/src/api/providers/bedrock.ts b/src/api/providers/bedrock.ts index 374b81d8a98..1b14bb7d7b5 100644 --- a/src/api/providers/bedrock.ts +++ b/src/api/providers/bedrock.ts @@ -665,13 +665,14 @@ Please check: const command = new ConverseCommand(payload) const response = await this.client.send(command) - if (response.output && response.output instanceof Uint8Array) { + if ( + response?.output?.message?.content && + response.output.message.content.length > 0 && + response.output.message.content[0].text && + response.output.message.content[0].text.trim().length > 0 + ) { try { - const outputStr = new TextDecoder().decode(response.output) - const output = JSON.parse(outputStr) - if (output.content) { - return output.content - } + return response.output.message.content[0].text } catch (parseError) { logger.error("Failed to parse Bedrock response", { ctx: "bedrock",