
Commit 25fa5e9

yiyuan-he and mxiamxia authored
feat: add support for new cohere command r models (#118)
*Description of changes:* Adding support for Cohere Command R models. The previous Cohere Command models are not yet fully deprecated ([EOL April 2025](https://docs.aws.amazon.com/bedrock/latest/userguide/model-lifecycle.html)), so we keep supporting them for now.

Beginning 11/05/24, calls to the old Cohere Command models throw a deprecation exception. I wasn't able to find an official announcement of this change, but I noticed it while testing in the Java SDK during development. Interestingly, calls to the old models still return a response, so the full Gen AI attributes are still generated for the time being.

![Screenshot 2024-11-05 at 5 01 11 PM](https://github.com/user-attachments/assets/52c2d30e-5c75-431d-8b14-85fde9938cab)

*Test Plan:* Verified that the Command R model attributes are generated with sample-app auto-instrumentation.

![Screenshot 2024-11-05 at 4 52 36 PM](https://github.com/user-attachments/assets/81f076cc-7945-421d-8c9f-1e754d7acced)

By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the terms of your choice.

---------

Co-authored-by: Min Xia <[email protected]>
1 parent 6fe7d4d · commit 25fa5e9
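For context, here is a minimal sketch of the call path this commit instruments. It is not part of the diff: the client setup, region, prompt, and parameter values are illustrative, and it assumes the AWS SDK for JavaScript v3 package `@aws-sdk/client-bedrock-runtime`.

```typescript
import { BedrockRuntimeClient, InvokeModelCommand } from '@aws-sdk/client-bedrock-runtime';

async function main(): Promise<void> {
  const client = new BedrockRuntimeClient({ region: 'us-east-1' }); // illustrative region

  // Command R takes `message` where the legacy Command models took `prompt`.
  const body = JSON.stringify({
    message: 'Describe the purpose of OpenTelemetry in one line.',
    max_tokens: 512,
    temperature: 0.5,
    p: 0.65,
  });

  const response = await client.send(new InvokeModelCommand({ modelId: 'cohere.command-r-v1:0', body }));
  console.log(new TextDecoder().decode(response.body));
}

main().catch(console.error);
```

With the auto-instrumentation active, the patch below maps `max_tokens`, `temperature`, and `p` from this request body onto the `gen_ai.request.*` span attributes, and approximates the `gen_ai.usage.*` token counts from the lengths of the request `message` and the response `text`.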

2 files changed: +82 -5 lines changed


aws-distro-opentelemetry-node-autoinstrumentation/src/patches/aws/services/bedrock.ts

Lines changed: 28 additions & 5 deletions
```diff
@@ -245,6 +245,22 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (requestBody.top_p !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.top_p;
       }
+    } else if (modelId.includes('cohere.command-r')) {
+      if (requestBody.max_tokens !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
+      }
+      if (requestBody.temperature !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE] = requestBody.temperature;
+      }
+      if (requestBody.p !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
+      }
+      if (requestBody.message !== undefined) {
+        // NOTE: We approximate the token count since this value is not directly available in the body
+        // According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
+        // https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.message.length / 6);
+      }
     } else if (modelId.includes('cohere.command')) {
       if (requestBody.max_tokens !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
@@ -255,6 +271,9 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (requestBody.p !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.p;
       }
+      if (requestBody.prompt !== undefined) {
+        spanAttributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS] = Math.ceil(requestBody.prompt.length / 6);
+      }
     } else if (modelId.includes('ai21.jamba')) {
       if (requestBody.max_tokens !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS] = requestBody.max_tokens;
@@ -265,7 +284,7 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (requestBody.top_p !== undefined) {
         spanAttributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P] = requestBody.top_p;
       }
-    } else if (modelId.includes('mistral.mistral')) {
+    } else if (modelId.includes('mistral')) {
       if (requestBody.prompt !== undefined) {
         // NOTE: We approximate the token count since this value is not directly available in the body
         // According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
@@ -329,13 +348,17 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
       if (responseBody.stop_reason !== undefined) {
         span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.stop_reason]);
       }
-    } else if (currentModelId.includes('cohere.command')) {
-      if (responseBody.prompt !== undefined) {
+    } else if (currentModelId.includes('cohere.command-r')) {
+      if (responseBody.text !== undefined) {
         // NOTE: We approximate the token count since this value is not directly available in the body
         // According to Bedrock docs they use (total_chars / 6) to approximate token count for pricing.
         // https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-prepare.html
-        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS, Math.ceil(responseBody.prompt.length / 6));
+        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS, Math.ceil(responseBody.text.length / 6));
       }
+      if (responseBody.finish_reason !== undefined) {
+        span.setAttribute(AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS, [responseBody.finish_reason]);
+      }
+    } else if (currentModelId.includes('cohere.command')) {
       if (responseBody.generations?.[0]?.text !== undefined) {
         span.setAttribute(
           AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS,
@@ -362,7 +385,7 @@ export class BedrockRuntimeServiceExtension implements ServiceExtension {
           responseBody.choices[0].finish_reason,
         ]);
       }
-    } else if (currentModelId.includes('mistral.mistral')) {
+    } else if (currentModelId.includes('mistral')) {
       if (responseBody.outputs?.[0]?.text !== undefined) {
         span.setAttribute(
           AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS,
```
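The NOTE comments in this diff rely on Bedrock's documented pricing heuristic of roughly six characters per token. Below is a small standalone sketch of that approximation, together with the two Cohere request shapes the patch distinguishes; the field values are made up, but the field names match what the patched code reads.

```typescript
// The chars/6 heuristic used in the patch, in isolation.
const approximateTokens = (text: string): number => Math.ceil(text.length / 6);

// Legacy Cohere Command sends `prompt`; Command R sends `message`.
const legacyCommandRequest = { prompt: 'Hello world', max_tokens: 100, temperature: 0.5, p: 0.9 };
const commandRRequest = { message: 'Hello world', max_tokens: 100, temperature: 0.5, p: 0.9 };

// 'Hello world' is 11 characters, so both approximate to Math.ceil(11 / 6) = 2 tokens.
console.log(approximateTokens(legacyCommandRequest.prompt)); // 2
console.log(approximateTokens(commandRRequest.message)); // 2
```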

aws-distro-opentelemetry-node-autoinstrumentation/test/patches/aws/services/bedrock.test.ts

Lines changed: 54 additions & 0 deletions
```diff
@@ -517,6 +517,60 @@ describe('BedrockRuntime', () => {
       expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
     });
 
+    it('Add Cohere Command R model attributes to span', async () => {
+      const modelId: string = 'cohere.command-r-v1:0';
+      const prompt: string = "Describe the purpose of a 'hello world' program in one line";
+      const nativeRequest: any = {
+        message: prompt,
+        max_tokens: 512,
+        temperature: 0.5,
+        p: 0.65,
+      };
+      const mockRequestBody: string = JSON.stringify(nativeRequest);
+      const mockResponseBody: any = {
+        finish_reason: 'COMPLETE',
+        text: 'test-generation-text',
+        prompt: prompt,
+        request: {
+          commandInput: {
+            modelId: modelId,
+          },
+        },
+      };
+
+      nock(`https://bedrock-runtime.${region}.amazonaws.com`)
+        .post(`/model/${encodeURIComponent(modelId)}/invoke`)
+        .reply(200, mockResponseBody);
+
+      await bedrock
+        .invokeModel({
+          modelId: modelId,
+          body: mockRequestBody,
+        })
+        .catch((err: any) => {
+          console.log('error', err);
+        });
+
+      const testSpans: ReadableSpan[] = getTestSpans();
+      const invokeModelSpans: ReadableSpan[] = testSpans.filter((s: ReadableSpan) => {
+        return s.name === 'BedrockRuntime.InvokeModel';
+      });
+      expect(invokeModelSpans.length).toBe(1);
+      const invokeModelSpan = invokeModelSpans[0];
+      expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_AGENT_ID]).toBeUndefined();
+      expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_KNOWLEDGE_BASE_ID]).toBeUndefined();
+      expect(invokeModelSpan.attributes[AWS_ATTRIBUTE_KEYS.AWS_BEDROCK_DATA_SOURCE_ID]).toBeUndefined();
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_SYSTEM]).toBe('aws_bedrock');
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MODEL]).toBe(modelId);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_MAX_TOKENS]).toBe(512);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TEMPERATURE]).toBe(0.5);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_REQUEST_TOP_P]).toBe(0.65);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_INPUT_TOKENS]).toBe(10);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_USAGE_OUTPUT_TOKENS]).toBe(4);
+      expect(invokeModelSpan.attributes[AwsSpanProcessingUtil.GEN_AI_RESPONSE_FINISH_REASONS]).toEqual(['COMPLETE']);
+      expect(invokeModelSpan.kind).toBe(SpanKind.CLIENT);
+    });
+
     it('Add Meta Llama model attributes to span', async () => {
       const modelId: string = 'meta.llama2-13b-chat-v1';
       const prompt: string = 'Describe the purpose of an interpreter program in one line.';
```
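A note on the test's hard-coded expectations: the 10 input and 4 output tokens follow directly from the chars/6 heuristic applied to the 59-character prompt and the 20-character mock generation text. A hypothetical helper (not part of the commit) could derive them instead of hard-coding:

```typescript
// Hypothetical helper mirroring the instrumentation's approximation.
const approxTokens = (s: string): number => Math.ceil(s.length / 6);

approxTokens("Describe the purpose of a 'hello world' program in one line"); // 59 chars -> 10
approxTokens('test-generation-text'); // 20 chars -> 4
```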
