Skip to content

Commit b8c6be2

Browse files
Add token parser for Bedrock & fix anthropic typo (#1851)
I am currently using Ragas with Bedrock and had to create new token usage parsers. This PR responds to [Issue-1151](#1151). Thanks for this package — it's nice 😄
1 parent b3c768b commit b8c6be2

File tree

5 files changed

+100
-18
lines changed

5 files changed

+100
-18
lines changed

src/ragas/cost.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,33 @@ def get_token_usage_for_anthropic(
100100
return TokenUsage(input_tokens=0, output_tokens=0)
101101

102102

103+
def get_token_usage_for_bedrock(
    llm_result: t.Union[LLMResult, ChatResult],
) -> TokenUsage:
    """Sum Bedrock token usage across every generation in ``llm_result``.

    Reads ``usage.prompt_tokens`` and ``usage.completion_tokens`` from each
    ChatGeneration's ``response_metadata`` (a missing key counts as 0, via
    ``get_from_dict``'s default).

    Returns
    -------
    TokenUsage
        The element-wise sum of all per-generation usages; a zero
        ``TokenUsage`` when no generation carries usage metadata.
    """
    token_usages = []
    for gs in llm_result.generations:
        for g in gs:
            # Only chat generations carry response_metadata; truthiness check
            # also guards against None, which `!= {}` would let through.
            if isinstance(g, ChatGeneration) and g.message.response_metadata:
                token_usages.append(
                    TokenUsage(
                        input_tokens=get_from_dict(
                            g.message.response_metadata,
                            "usage.prompt_tokens",
                            0,
                        ),
                        output_tokens=get_from_dict(
                            g.message.response_metadata,
                            "usage.completion_tokens",
                            0,
                        ),
                    )
                )
    # Single exit point: the original returned from inside the outer loop, so
    # only the first generation group was ever summed (and the trailing
    # zero-return was unreachable/dead). Summing after the loop counts every
    # group, and the `start` value of sum() covers the empty case.
    return sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0))
128+
129+
103130
class CostCallbackHandler(BaseCallbackHandler):
104131
def __init__(self, token_usage_parser: TokenUsageParser):
105132
self.token_usage_parser = token_usage_parser

src/ragas/integrations/swarm.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import json
22
from typing import Any, Dict, List, Union
3+
34
from ragas.messages import AIMessage, HumanMessage, ToolCall, ToolMessage
45

56

src/ragas/llms/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ def is_finished(self, response: LLMResult) -> bool:
182182
elif resp_message.response_metadata.get("stop_reason") is not None:
183183
stop_reason = resp_message.response_metadata.get("stop_reason")
184184
is_finished_list.append(
185-
stop_reason in ["end_turn", "STOP", "MAX_TOKENS"]
185+
stop_reason in ["end_turn", "stop", "STOP", "MAX_TOKENS"]
186186
)
187187
# default to True
188188
else:

src/ragas/prompt/pydantic_prompt.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
from __future__ import annotations
22

33
import copy
4+
import hashlib
45
import json
56
import logging
67
import os
7-
import hashlib
8-
98
import typing as t
109

1110
from langchain_core.exceptions import OutputParserException
@@ -228,7 +227,7 @@ async def adapt(
228227
"""
229228
Adapt the prompt to a new language.
230229
"""
231-
230+
232231
strings = get_all_strings(self.examples)
233232
translated_strings = await translate_statements_prompt.generate(
234233
llm=llm,
@@ -275,7 +274,7 @@ def __str__(self):
275274
ensure_ascii=False,
276275
)[1:-1]
277276
return f"{self.__class__.__name__}({json_str})"
278-
277+
279278
def __hash__(self):
    """Return a content-based hash of this prompt.

    The hash covers, in order: name, input/output model class names,
    instruction, each example's input/output JSON dumps, and language —
    so two prompts with identical content hash identically.
    """
    # Gather every hashed field in a fixed order, then digest them in one
    # pass; the update order matches the field order exactly.
    parts = [
        self.name,
        self.input_model.__name__,
        self.output_model.__name__,
        self.instruction,
    ]
    for input_model, output_model in self.examples:
        parts.append(input_model.model_dump_json())
        parts.append(output_model.model_dump_json())
    parts.append(self.language)

    digest = hashlib.sha256()
    for part in parts:
        digest.update(part.encode("utf-8"))

    # Hash values must be ints; take the digest as a big integer.
    return int(digest.hexdigest(), 16)
303-
302+
304303
def __eq__(self, other):
305304
if not isinstance(other, PydanticPrompt):
306305
return False

tests/unit/test_cost.py

Lines changed: 57 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
CostCallbackHandler,
77
TokenUsage,
88
get_token_usage_for_anthropic,
9+
get_token_usage_for_bedrock,
910
get_token_usage_for_openai,
1011
)
1112

@@ -62,7 +63,7 @@ def test_token_usage_cost():
6263
},
6364
)
6465

65-
athropic_llm_result = LLMResult(
66+
anthropic_llm_result = LLMResult(
6667
generations=[
6768
[
6869
ChatGeneration(
@@ -82,16 +83,70 @@ def test_token_usage_cost():
8283
llm_output={},
8384
)
8485

86+
def _bedrock_chat_result(stop_reason: str, model_id: str) -> LLMResult:
    # Shared shape for Bedrock fixtures: one ChatGeneration whose
    # response_metadata carries a usage block of 10 prompt / 10 completion
    # tokens; only stop_reason and model_id vary between providers.
    metadata = {
        "usage": {
            "prompt_tokens": 10,
            "completion_tokens": 10,
            "total_tokens": 20,
        },
        "stop_reason": stop_reason,
        "model_id": model_id,
    }
    generation = ChatGeneration(
        text="Hello, world!",
        message=AIMessage(content="Hello, world!", response_metadata=metadata),
    )
    return LLMResult(generations=[[generation]], llm_output={})


bedrock_llama_result = _bedrock_chat_result(
    stop_reason="stop",
    model_id="us.meta.llama3-1-70b-instruct-v1:0",
)

bedrock_claude_result = _bedrock_chat_result(
    stop_reason="end_turn",
    model_id="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
)
131+
85132

86133
def test_parse_llm_results():
    """Each provider-specific parser extracts the expected token counts."""
    cases = [
        # (parser, llm_result fixture, expected usage)
        (
            get_token_usage_for_openai,
            openai_llm_result,
            TokenUsage(input_tokens=10, output_tokens=10),
        ),
        (
            get_token_usage_for_anthropic,
            anthropic_llm_result,
            TokenUsage(input_tokens=9, output_tokens=12),
        ),
        (
            get_token_usage_for_bedrock,
            bedrock_llama_result,
            TokenUsage(input_tokens=10, output_tokens=10),
        ),
        (
            get_token_usage_for_bedrock,
            bedrock_claude_result,
            TokenUsage(input_tokens=10, output_tokens=10),
        ),
    ]
    for parser, llm_result, expected in cases:
        assert parser(llm_result) == expected
149+
95150

96151
def test_cost_callback_handler():
97152
cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)

0 commit comments

Comments
 (0)