Skip to content

Commit dc03217

Browse files
committed
[feat] add support for custom model pricing and update cost calculation logic
1 parent 3365a07 commit dc03217

File tree

2 files changed

+25
-19
lines changed

2 files changed

+25
-19
lines changed

edenai_apis/llmengine/clients/litellm_client/litellm_client.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,8 @@ def completion(
187187
call_params["model_list"] = model_list
188188
if user is not None:
189189
call_params["user"] = user
190-
# See if there's a custom pricing here
190+
# See if there's custom pricing (model_pricing for extended pricing, or legacy per-token pricing)
191+
model_pricing = kwargs.pop("model_pricing", None)
191192
custom_pricing = {}
192193
if kwargs.get("input_cost_per_token", None) and kwargs.get(
193194
"output_cost_per_token", None
@@ -200,6 +201,13 @@ def completion(
200201
if drop_invalid_params == True:
201202
litellm.drop_params = True
202203
kwargs.pop("moderate_content", None)
204+
# Register custom model pricing in litellm's registry for extended pricing support
205+
if model_pricing:
206+
# Merge with existing pricing to preserve other fields (max_tokens, mode, etc.)
207+
if model_name in litellm.model_cost:
208+
litellm.model_cost[model_name].update(model_pricing)
209+
else:
210+
litellm.model_cost[model_name] = model_pricing
203211
provider_start_time = time.time_ns()
204212
c_response = completion(**call_params, **kwargs)
205213
provider_end_time = time.time_ns()
@@ -216,7 +224,8 @@ def generate_chunks():
216224
"completion_response": c_response,
217225
"call_type": "completion",
218226
}
219-
if len(custom_pricing.keys()) > 0:
227+
# Use model_pricing via registry lookup, or fall back to legacy custom_cost_per_token
228+
if not model_pricing and len(custom_pricing.keys()) > 0:
220229
cost_calc_params["custom_cost_per_token"] = custom_pricing
221230
response = {
222231
**c_response.model_dump(),
@@ -807,7 +816,8 @@ async def acompletion(
807816
call_params["model_list"] = model_list
808817
if user is not None:
809818
call_params["user"] = user
810-
# See if there's a custom pricing here
819+
# See if there's custom pricing (model_pricing for extended pricing, or legacy per-token pricing)
820+
model_pricing = kwargs.pop("model_pricing", None)
811821
custom_pricing = {}
812822
if kwargs.get("input_cost_per_token", None) and kwargs.get(
813823
"output_cost_per_token", None
@@ -820,6 +830,13 @@ async def acompletion(
820830
if drop_invalid_params == True:
821831
litellm.drop_params = True
822832
kwargs.pop("moderate_content", None)
833+
# Register custom model pricing in litellm's registry for extended pricing support
834+
if model_pricing:
835+
# Merge with existing pricing to preserve other fields (max_tokens, mode, etc.)
836+
if model_name in litellm.model_cost:
837+
litellm.model_cost[model_name].update(model_pricing)
838+
else:
839+
litellm.model_cost[model_name] = model_pricing
823840
provider_start_time = time.time_ns()
824841
c_response = await acompletion(**call_params, **kwargs)
825842
provider_end_time = time.time_ns()
@@ -834,9 +851,10 @@ async def generate_chunks():
834851
else:
835852
cost_calc_params = {
836853
"completion_response": c_response,
837-
"call_type": "completion",
854+
"call_type": "acompletion",
838855
}
839-
if len(custom_pricing.keys()) > 0:
856+
# Use model_pricing via registry lookup, or fall back to legacy custom_cost_per_token
857+
if not model_pricing and len(custom_pricing.keys()) > 0:
840858
cost_calc_params["custom_cost_per_token"] = custom_pricing
841859
response = {
842860
**c_response.model_dump(),

edenai_apis/llmengine/utils/calculate_cost.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,6 @@
55

66
def calculate_cost(
77
completion_response: Union[ModelResponse, dict],
8-
model: str,
9-
call_type: Literal[
10-
"completion",
11-
"embedding",
12-
"image_generation",
13-
"moderation",
14-
"acompletion",
15-
"aembedding",
16-
"aimage_generation",
17-
"amoderation",
18-
"arerank",
19-
] = "completion",
208
input_cost_per_token: Optional[float] = None,
219
output_cost_per_token: Optional[float] = None,
2210
) -> float:
@@ -35,7 +23,7 @@ def calculate_cost(
3523
"""
3624
cost_calc_params = {
3725
"completion_response": completion_response,
38-
"call_type": call_type,
26+
"call_type": "acompletion", # For now, we only support completion cost calculation
3927
}
4028

4129
if input_cost_per_token is not None and output_cost_per_token is not None:
@@ -44,4 +32,4 @@ def calculate_cost(
4432
"output_cost_per_token": output_cost_per_token,
4533
}
4634

47-
return completion_cost(**cost_calc_params, model=model)
35+
return completion_cost(**cost_calc_params)

0 commit comments

Comments
 (0)