
Commit 2dae552

OpenAI model cost (#2287)

fixed formatting issues from #1946

Co-authored-by: JonasElburgUVA <[email protected]>

Parent: abfd849

File tree: 2 files changed (+34, -10 lines)


src/ragas/cost.py (19 additions, 5 deletions)
```diff
@@ -22,6 +22,7 @@ def __add__(self, y: "TokenUsage") -> "TokenUsage":
             return TokenUsage(
                 input_tokens=self.input_tokens + y.input_tokens,
                 output_tokens=self.output_tokens + y.output_tokens,
+                model=self.model,
             )
         else:
             raise ValueError("Cannot add TokenUsage objects with different models")
```
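For context, a minimal sketch of what this change means for callers, assuming `TokenUsage` exposes `input_tokens`, `output_tokens`, and an optional `model` field as the tests below suggest:

```python
# Illustrative only, not library code: adding two usages for the same model
# now preserves that model on the result; mixing models still raises.
from ragas.cost import TokenUsage

a = TokenUsage(input_tokens=10, output_tokens=5, model="gpt-4o")
b = TokenUsage(input_tokens=3, output_tokens=7, model="gpt-4o")

total = a + b
assert total == TokenUsage(input_tokens=13, output_tokens=12, model="gpt-4o")

c = TokenUsage(input_tokens=1, output_tokens=1, model="claude-3-opus-20240229")
try:
    a + c  # different models cannot be summed
except ValueError:
    pass
```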
```diff
@@ -67,8 +68,11 @@ def get_token_usage_for_openai(
         return TokenUsage(input_tokens=0, output_tokens=0)
     output_tokens = get_from_dict(llm_output, "token_usage.completion_tokens", 0)
     input_tokens = get_from_dict(llm_output, "token_usage.prompt_tokens", 0)
+    model = get_from_dict(llm_output, "model_name", "")
 
-    return TokenUsage(input_tokens=input_tokens, output_tokens=output_tokens)
+    return TokenUsage(
+        input_tokens=input_tokens, output_tokens=output_tokens, model=model
+    )
 
 
 def get_token_usage_for_anthropic(
```
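A small usage sketch of the OpenAI parser after this change; the `LLMResult` shape here is an assumption modeled on what `langchain-openai` puts in `llm_output` (and on what the `openai_llm_result` fixture in the tests below presumably looks like), with `model_name` sitting alongside `token_usage`:

```python
# Hypothetical fixture: the real openai_llm_result test fixture is not shown
# in this diff, so this shape is an assumption based on langchain's llm_output.
from langchain_core.outputs import Generation, LLMResult

from ragas.cost import get_token_usage_for_openai

result = LLMResult(
    generations=[[Generation(text="hello")]],
    llm_output={
        "token_usage": {"prompt_tokens": 10, "completion_tokens": 10},
        "model_name": "gpt-4o",
    },
)
usage = get_token_usage_for_openai(result)
# usage == TokenUsage(input_tokens=10, output_tokens=10, model="gpt-4o")
```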
```diff
@@ -92,10 +96,15 @@ def get_token_usage_for_anthropic(
                         "usage.output_tokens",
                         0,
                     ),
+                    model=get_from_dict(
+                        g.message.response_metadata, "model", ""
+                    ),
                 )
             )
-
-        return sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0))
+        model = next((usage.model for usage in token_usages if usage.model), "")
+        return sum(
+            token_usages, TokenUsage(input_tokens=0, output_tokens=0, model=model)
+        )
     else:
         return TokenUsage(input_tokens=0, output_tokens=0)
 
```
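The aggregation pattern introduced here (and again in the Bedrock parser below) picks the first non-empty model name and seeds the zero-valued start of `sum()` with it, so `TokenUsage.__add__` sees matching models and the summed usage keeps the name. A standalone sketch of the idea, not the library code itself:

```python
# Sketch of the pattern above: seed sum() with the first non-empty model so
# TokenUsage.__add__ sees matching models and the result carries the name.
from ragas.cost import TokenUsage

token_usages = [
    TokenUsage(input_tokens=4, output_tokens=6, model="claude-3-opus-20240229"),
    TokenUsage(input_tokens=5, output_tokens=6, model="claude-3-opus-20240229"),
]
model = next((usage.model for usage in token_usages if usage.model), "")
total = sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0, model=model))
assert total == TokenUsage(
    input_tokens=9, output_tokens=12, model="claude-3-opus-20240229"
)
```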

```diff
@@ -120,10 +129,15 @@ def get_token_usage_for_bedrock(
                         "usage.completion_tokens",
                         0,
                     ),
+                    model=get_from_dict(
+                        g.message.response_metadata, "model_id", ""
+                    ),
                 )
             )
-
-        return sum(token_usages, TokenUsage(input_tokens=0, output_tokens=0))
+        model = next((usage.model for usage in token_usages if usage.model), "")
+        return sum(
+            token_usages, TokenUsage(input_tokens=0, output_tokens=0, model=model)
+        )
     return TokenUsage(input_tokens=0, output_tokens=0)
 
 
```
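`get_from_dict` is the dotted-path helper all three parsers rely on; a quick sketch of its assumed behavior (the helper itself is not part of this diff, and the import path below is an assumption):

```python
# Assumed behavior of the helper used above: walk a nested dict by dotted key,
# returning the default when any level of the path is missing.
from ragas.utils import get_from_dict  # assumed import path; not shown in this diff

metadata = {
    "usage": {"prompt_tokens": 10, "completion_tokens": 10},
    "model_id": "us.meta.llama3-1-70b-instruct-v1:0",
}
get_from_dict(metadata, "usage.completion_tokens", 0)  # -> 10
get_from_dict(metadata, "usage.total_tokens", 0)       # -> 0 (missing, default)
get_from_dict(metadata, "model_id", "")                # -> "us.meta.llama3-1-70b-instruct-v1:0"
```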

tests/unit/test_cost.py (15 additions, 5 deletions)
```diff
@@ -133,27 +133,37 @@ def test_token_usage_cost():
 def test_parse_llm_results():
     # openai
     token_usage = get_token_usage_for_openai(openai_llm_result)
-    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
+    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10, model="gpt-4o")
 
     # anthropic
     token_usage = get_token_usage_for_anthropic(anthropic_llm_result)
-    assert token_usage == TokenUsage(input_tokens=9, output_tokens=12)
+    assert token_usage == TokenUsage(
+        input_tokens=9, output_tokens=12, model="claude-3-opus-20240229"
+    )
 
     # Bedrock LLaMa
     token_usage = get_token_usage_for_bedrock(bedrock_llama_result)
-    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
+    assert token_usage == TokenUsage(
+        input_tokens=10, output_tokens=10, model="us.meta.llama3-1-70b-instruct-v1:0"
+    )
 
     # Bedrock Claude
     token_usage = get_token_usage_for_bedrock(bedrock_claude_result)
-    assert token_usage == TokenUsage(input_tokens=10, output_tokens=10)
+    assert token_usage == TokenUsage(
+        input_tokens=10,
+        output_tokens=10,
+        model="us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+    )
 
 
 def test_cost_callback_handler():
     cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)
     cost_cb.on_llm_end(openai_llm_result)
 
     # cost
-    assert cost_cb.total_tokens() == TokenUsage(input_tokens=10, output_tokens=10)
+    assert cost_cb.total_tokens() == TokenUsage(
+        input_tokens=10, output_tokens=10, model="gpt-4o"
+    )
 
     assert cost_cb.total_cost(0.1) == 2.0
     assert (
```
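Based on the callback test above, end-to-end usage looks roughly like the following; the `LLMResult` is a hypothetical stand-in for the `openai_llm_result` fixture (not shown in this diff), and the `0.1` per-token rate is the one the test uses:

```python
# Rough end-to-end sketch based on the test above, not guaranteed API surface.
from langchain_core.outputs import Generation, LLMResult

from ragas.cost import CostCallbackHandler, get_token_usage_for_openai

result = LLMResult(
    generations=[[Generation(text="hello")]],
    llm_output={
        "token_usage": {"prompt_tokens": 10, "completion_tokens": 10},
        "model_name": "gpt-4o",
    },
)

cost_cb = CostCallbackHandler(token_usage_parser=get_token_usage_for_openai)
cost_cb.on_llm_end(result)  # normally invoked by the framework per LLM call

# With 10 input + 10 output tokens at a 0.1 per-token rate,
# the total cost is (10 + 10) * 0.1 == 2.0, matching the assertion above.
print(cost_cb.total_tokens())  # TokenUsage(input_tokens=10, output_tokens=10, model="gpt-4o")
print(cost_cb.total_cost(0.1))
```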
