Skip to content

Commit 41663ab

Browse files
committed
feat: add support for new cohere command r models
1 parent e8c96ae commit 41663ab

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

aws-distro-opentelemetry-node-autoinstrumentation/src/patches/aws/services/bedrock.ts

Lines changed: 24 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,19 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
245245
if (requestBody.top_p !== undefined) {
246246
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.top_p;
247247
}
248+
} else if (modelId.includes('cohere.command-r')) {
249+
if (requestBody.max_tokens !== undefined) {
250+
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
251+
}
252+
if (requestBody.temperature !== undefined) {
253+
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE] = requestBody.temperature;
254+
}
255+
if (requestBody.p !== undefined) {
256+
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
257+
}
258+
if (requestBody.message !== undefined) {
259+
spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.message.length / 6);
260+
}
248261
} else if (modelId.includes('cohere.command')) {
249262
if (requestBody.max_tokens !== undefined) {
250263
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
@@ -255,6 +268,9 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
255268
if (requestBody.p !== undefined) {
256269
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
257270
}
271+
if (requestBody.prompt !== undefined) {
272+
spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.prompt.length / 6);
273+
}
258274
} else if (modelId.includes('ai21.jamba')) {
259275
if (requestBody.max_tokens !== undefined) {
260276
spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
@@ -329,13 +345,18 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
329345
if (responseBody.stop_reason !== undefined) {
330346
span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.stop_reason]);
331347
}
332-
} else if (currentModelId.includes('cohere.command')) {
333-
if (responseBody.prompt !== undefined) {
348+
} else if (currentModelId.includes('cohere.command-r')) {
349+
console.log('Response Body:', responseBody);
350+
if (responseBody.text !== undefined) {
334351
// NOTE: We approximate the token count since this value is not directly available in the body
335352
// According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
336353
// https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
337-
span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS, Math.ceil(responseBody.prompt.length / 6));
354+
span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS, Math.ceil(responseBody.text.length / 6));
338355
}
356+
if (responseBody.finish_reason !== undefined) {
357+
span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.finish_reason]);
358+
}
359+
} else if (currentModelId.includes('cohere.command')) {
339360
if (responseBody.generations?.[0]?.text !== undefined) {
340361
span.setAttribute(
341362
AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS,

aws-distro-opentelemetry-node-autoinstrumentation/test/patches/aws/services/bedrock.test.ts

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -517,6 +517,60 @@ describe('BedrockRuntime', () => {
517517
expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
518518
});
519519

520+
it('Add Cohere Command R model attributes to span', async () => {
521+
const modelId: string = 'cohere.command-r-v1:0';
522+
const prompt: string = "Describe the purpose of a 'hello world' program in one line";
523+
const nativeRequest: any = {
524+
message: prompt,
525+
max_tokens: 512,
526+
temperature: 0.5,
527+
p: 0.65,
528+
};
529+
const mockRequestBody: string = JSON.stringify(nativeRequest);
530+
const mockResponseBody: any = {
531+
finish_reason: 'COMPLETE',
532+
text: 'test-generation-text',
533+
prompt: prompt,
534+
request: {
535+
commandInput: {
536+
modelId: modelId,
537+
},
538+
},
539+
};
540+
541+
nock(`https://bedrock-runtime.${region}.amazonaws.com`)
542+
.post(`/model/${encodeURIComponent(modelId)}/invoke`)
543+
.reply(200, mockResponseBody);
544+
545+
await bedrock
546+
.invokeModel({
547+
modelId: modelId,
548+
body: mockRequestBody,
549+
})
550+
.catch((err: any) => {
551+
console.log('error', err);
552+
});
553+
554+
const testSpans: ReadableSpan[] = getTestSpans();
555+
const invokeModelSpans: ReadableSpan[] = testSpans.filter((s: ReadableSpan) => {
556+
return s.name === 'BedrockRuntime.InvokeModel';
557+
});
558+
expect(invokeModelSpans.length).toBe(1);
559+
const invokeModelSpan = invokeModelSpans[0];
560+
expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_AGENT_ID]).toBeUndefined();
561+
expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_KNOWLEDGE_BASE_ID]).toBeUndefined();
562+
expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_DATA_SOURCE_ID]).toBeUndefined();
563+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_SYSTEM]).toBe('aws_bedrock');
564+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MODEL]).toBe(modelId);
565+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS]).toBe(512);
566+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE]).toBe(0.5);
567+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P]).toBe(0.65);
568+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS]).toBe(10);
569+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS]).toBe(4);
570+
expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS]).toEqual(['COMPLETE']);
571+
expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
572+
});
573+
520574
it('Add Meta Llama model attributes to span', async () => {
521575
const modelId: string = 'meta.llama2-13b-chat-v1';
522576
const prompt: string = 'Describe the purpose of an interpreter program in one line.';

0 commit comments

Comments
 (0)