Skip to content

Commit 3ac5883

Browse files
committed
Add support for converse command
1 parent 4a2c81f commit 3ac5883

File tree

1 file changed

+75
-13
lines changed

1 file changed

+75
-13
lines changed

library/sinks/AwsSDKVersion3.ts

Lines changed: 75 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,51 @@ import { wrapExport } from "../agent/hooks/wrapExport";
44
import { Wrapper } from "../agent/Wrapper";
55
import { isPlainObject } from "../helpers/isPlainObject";
66

7-
type Usage = {
7+
type InvokeUsage = {
88
input_tokens?: number;
99
output_tokens?: number;
1010
};
1111

12-
function isUsage(usage: unknown): usage is Usage {
12+
function isUsage(usage: unknown): usage is InvokeUsage {
1313
return (
1414
isPlainObject(usage) &&
1515
typeof usage.input_tokens === "number" &&
1616
typeof usage.output_tokens === "number"
1717
);
1818
}
1919

20-
type Response = {
20+
type InvokeResponse = {
2121
body?: Uint8Array;
2222
};
2323

24-
function isResponse(response: unknown): response is Response {
24+
function isInvokeResponse(response: unknown): response is InvokeResponse {
25+
return isPlainObject(response);
26+
}
27+
28+
type ConverseUsage = {
29+
inputTokens?: number;
30+
outputTokens?: number;
31+
};
32+
33+
function isConverseUsage(usage: unknown): usage is ConverseUsage {
34+
return (
35+
isPlainObject(usage) &&
36+
typeof usage.inputTokens === "number" &&
37+
typeof usage.outputTokens === "number"
38+
);
39+
}
40+
41+
type ConverseResponse = {
42+
usage?: ConverseUsage;
43+
};
44+
45+
function isConverseResponse(response: unknown): response is ConverseResponse {
2546
return isPlainObject(response);
2647
}
2748

2849
export class AwsSDKVersion3 implements Wrapper {
29-
private processResponse(response: unknown, agent: Agent) {
30-
if (!isResponse(response)) {
50+
private processInvokeModelResponse(response: unknown, agent: Agent) {
51+
if (!isInvokeResponse(response)) {
3152
return;
3253
}
3354

@@ -61,25 +82,66 @@ export class AwsSDKVersion3 implements Wrapper {
6182
}
6283
}
6384

85+
private processConverseResponse(
86+
response: unknown,
87+
command: unknown,
88+
agent: Agent
89+
) {
90+
// @ts-expect-error We don't know the type of command
91+
if (!command || !command.input || !command.input.modelId) {
92+
return;
93+
}
94+
95+
if (!isConverseResponse(response)) {
96+
return;
97+
}
98+
99+
// @ts-expect-error We don't know the type of command
100+
const modalId: string = command.input.modelId;
101+
102+
let inputTokens = 0;
103+
let outputTokens = 0;
104+
105+
if (isConverseUsage(response.usage)) {
106+
inputTokens = response.usage.inputTokens || 0;
107+
outputTokens = response.usage.outputTokens || 0;
108+
}
109+
110+
const aiStats = agent.getAIStatistics();
111+
aiStats.onAICall({
112+
provider: "bedrock",
113+
model: modalId,
114+
inputTokens: inputTokens,
115+
outputTokens: outputTokens,
116+
});
117+
}
118+
64119
wrap(hooks: Hooks) {
65-
// Note: Converse command is not supported yet
66120
hooks
67121
.addPackage("@aws-sdk/client-bedrock-runtime")
68122
.withVersion("^3.0.0")
69123
.onRequire((exports, pkgInfo) => {
70-
if (exports.BedrockRuntimeClient && exports.InvokeModelCommand) {
124+
if (exports.BedrockRuntimeClient) {
71125
wrapExport(exports.BedrockRuntimeClient.prototype, "send", pkgInfo, {
72126
kind: "ai_op",
73127
modifyReturnValue: (args, returnValue, agent) => {
74128
if (args.length > 0) {
75-
if (
76-
returnValue instanceof Promise &&
77-
args[0] instanceof exports.InvokeModelCommand
78-
) {
129+
const command = args[0];
130+
if (returnValue instanceof Promise) {
79131
// Inspect the response after the promise resolves, it won't change the original promise
80132
returnValue.then((response) => {
81133
try {
82-
this.processResponse(response, agent);
134+
if (
135+
exports.InvokeModelCommand &&
136+
command instanceof exports.InvokeModelCommand
137+
) {
138+
this.processInvokeModelResponse(response, agent);
139+
} else if (
140+
exports.ConverseCommand &&
141+
command instanceof exports.ConverseCommand
142+
) {
143+
this.processConverseResponse(response, command, agent);
144+
}
83145
} catch {
84146
// If we don't catch these errors, it will result in an unhandled promise rejection!
85147
}

0 commit comments

Comments
 (0)