Skip to content

Commit a221682

Browse files
committed
lint
1 parent 7b9e4d5 commit a221682

File tree

1 file changed

+46
-45
lines changed

1 file changed

+46
-45
lines changed

src/any_llm/providers/base_framework.py

Lines changed: 46 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -87,12 +87,7 @@ def _convert_messages(self, messages: list[dict[str, Any]]) -> Any:
8787
pass
8888

8989
@abstractmethod
def _make_api_call(self, model: str, messages: Any, **kwargs: Any) -> Any:
    """Make the actual API call to the provider.

    Args:
        model: Provider-specific model identifier.
        messages: Messages in the provider's native format (presumably the
            output of ``_convert_messages`` — confirm against callers).
        **kwargs: Additional provider-specific request options, passed
            through verbatim.

    Returns:
        The provider's raw response object.
    """
9893

@@ -108,6 +103,7 @@ class BaseCustomProvider(BaseProviderFramework, ABC):
108103
109104
Examples: Anthropic, Google, Cohere, Mistral, Ollama
110105
"""
106+
111107
pass
112108

113109

@@ -152,50 +148,51 @@ def create_openai_completion(
152148

153149
# === NEW COMPREHENSIVE RESPONSE CONVERSION UTILITIES ===
154150

151+
155152
def create_tool_calls_from_list(tool_calls_data: list[dict[str, Any]]) -> list[ChatCompletionMessageToolCall]:
    """Convert provider tool-call dicts into ChatCompletionMessageToolCall objects.

    Handles the structural variations seen across providers: alternate id
    keys, function info nested under ``"function"`` or inlined on the tool
    call itself, and arguments supplied as a dict, a string, or any other
    object.
    """
    converted = []
    for raw in tool_calls_data:
        # Accept either "id" or "tool_call_id"; synthesize a fallback id when
        # neither is present. NOTE(review): hash() is not stable across
        # interpreter runs, so fallback ids are not reproducible.
        call_id = raw.get("id") or raw.get("tool_call_id") or f"call_{hash(str(raw))}"

        # Function info may be nested under "function", or some providers put
        # "name"/"arguments" (or "input") directly on the tool call.
        fn = raw.get("function", {})
        if not fn and "name" in raw:
            fn = {
                "name": raw["name"],
                "arguments": raw.get("arguments", raw.get("input", {})),
            }

        fn_name = fn.get("name", "")
        fn_args = fn.get("arguments", {})

        # OpenAI expects arguments as a JSON-encoded string.
        if isinstance(fn_args, dict):
            fn_args = json.dumps(fn_args)
        elif not isinstance(fn_args, str):
            fn_args = str(fn_args)

        converted.append(create_openai_tool_call(call_id, fn_name, fn_args))

    return converted
188185

189186

190187
def create_choice_from_message_data(
191188
message_data: dict[str, Any],
192189
index: int = 0,
193190
finish_reason: str = "stop",
194-
finish_reason_mapping: Optional[dict[str, str]] = None
191+
finish_reason_mapping: Optional[dict[str, str]] = None,
195192
) -> Choice:
196193
"""
197194
Create a Choice object from message data, handling tool calls and content.
198-
195+
199196
Args:
200197
message_data: Dictionary containing message content and tool calls
201198
index: Choice index (default 0)
@@ -205,20 +202,20 @@ def create_choice_from_message_data(
205202
# Apply finish reason mapping if provided
206203
if finish_reason_mapping and finish_reason in finish_reason_mapping:
207204
finish_reason = finish_reason_mapping[finish_reason]
208-
205+
209206
# Extract tool calls if present
210207
tool_calls = None
211208
tool_calls_data = message_data.get("tool_calls", [])
212209
if tool_calls_data:
213210
tool_calls = create_tool_calls_from_list(tool_calls_data)
214-
211+
215212
# Create the message
216213
message = create_openai_message(
217214
role=message_data.get("role", "assistant"),
218215
content=message_data.get("content"),
219216
tool_calls=tool_calls,
220217
)
221-
218+
222219
return Choice(
223220
finish_reason=finish_reason, # type: ignore[arg-type]
224221
index=index,
@@ -227,29 +224,28 @@ def create_choice_from_message_data(
227224

228225

229226
def create_usage_from_data(
230-
usage_data: dict[str, Any],
231-
token_field_mapping: Optional[dict[str, str]] = None
227+
usage_data: dict[str, Any], token_field_mapping: Optional[dict[str, str]] = None
232228
) -> CompletionUsage:
233229
"""
234230
Create CompletionUsage from provider usage data.
235-
231+
236232
Args:
237233
usage_data: Dictionary containing usage information
238234
token_field_mapping: Optional mapping for field names (e.g., {"input_tokens": "prompt_tokens"})
239235
"""
240236
# Default field mapping
241237
default_mapping = {
242238
"completion_tokens": "completion_tokens",
243-
"prompt_tokens": "prompt_tokens",
239+
"prompt_tokens": "prompt_tokens",
244240
"total_tokens": "total_tokens",
245241
}
246-
242+
247243
# Apply custom mapping if provided
248244
if token_field_mapping:
249245
for openai_field, provider_field in token_field_mapping.items():
250246
if provider_field in usage_data:
251247
default_mapping[openai_field] = provider_field
252-
248+
253249
return CompletionUsage(
254250
completion_tokens=usage_data.get(default_mapping["completion_tokens"], 0),
255251
prompt_tokens=usage_data.get(default_mapping["prompt_tokens"], 0),
@@ -266,13 +262,13 @@ def create_completion_from_response(
266262
id_field: str = "id",
267263
created_field: str = "created",
268264
choices_field: str = "choices",
269-
usage_field: str = "usage"
265+
usage_field: str = "usage",
270266
) -> ChatCompletion:
271267
"""
272268
Create a complete ChatCompletion from provider response data.
273-
269+
274270
This is the main utility that most providers can use to convert their responses.
275-
271+
276272
Args:
277273
response_data: The raw response from the provider
278274
model: Model name to use in the response
@@ -287,28 +283,30 @@ def create_completion_from_response(
287283
# Extract choices
288284
choices = []
289285
choices_data = response_data.get(choices_field, [])
290-
286+
291287
# Handle single choice responses (common pattern)
292288
if not choices_data and "message" in response_data:
293-
choices_data = [{"message": response_data["message"], "finish_reason": response_data.get("finish_reason", "stop")}]
294-
289+
choices_data = [
290+
{"message": response_data["message"], "finish_reason": response_data.get("finish_reason", "stop")}
291+
]
292+
295293
for i, choice_data in enumerate(choices_data):
296294
choice = create_choice_from_message_data(
297295
choice_data.get("message", choice_data),
298296
index=i,
299297
finish_reason=choice_data.get("finish_reason", "stop"),
300-
finish_reason_mapping=finish_reason_mapping
298+
finish_reason_mapping=finish_reason_mapping,
301299
)
302300
choices.append(choice)
303-
301+
304302
# Create usage if available
305303
usage = None
306304
if usage_field in response_data and response_data[usage_field]:
307305
usage = create_usage_from_data(response_data[usage_field], token_field_mapping)
308-
306+
309307
# Generate ID if not present
310308
response_id = response_data.get(id_field, f"{provider_name}_{hash(str(response_data))}")
311-
309+
312310
return create_openai_completion(
313311
id=response_id,
314312
model=model,
@@ -320,32 +318,34 @@ def create_completion_from_response(
320318

321319
# === TOOL SPECIFICATION CONVERSION UTILITIES ===
322320

321+
323322
def convert_openai_tools_to_generic(openai_tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Convert OpenAI tool specifications to a provider-neutral format.

    Entries whose ``"type"`` is not ``"function"`` are dropped. Each returned
    dict carries the standardized keys ``"name"``, ``"description"`` (default
    ``""``), and ``"parameters"`` (default ``{}``), ready to be transformed
    into a provider-specific schema.
    """
    return [
        {
            "name": tool["function"]["name"],
            "description": tool["function"].get("description", ""),
            "parameters": tool["function"].get("parameters", {}),
        }
        for tool in openai_tools
        if tool.get("type") == "function"
    ]
345344

346345

347346
# === MESSAGE CONVERSION UTILITIES ===
348347

348+
349349
def standardize_message_content(content: Any) -> str:
350350
"""Convert message content to string format, handling various input types."""
351351
if content is None:
@@ -360,23 +360,24 @@ def standardize_message_content(content: Any) -> str:
360360
def extract_system_message(messages: list[dict[str, Any]]) -> tuple[str, list[dict[str, Any]]]:
    """Split the system message out of a messages list.

    Returns:
        A ``(system_message_content, remaining_messages)`` tuple. If several
        system messages are present, the last one wins; non-system messages
        keep their original relative order.
    """
    system_content = ""
    others: list[dict[str, Any]] = []
    for msg in messages:
        if msg.get("role") == "system":
            system_content = standardize_message_content(msg.get("content", ""))
        else:
            others.append(msg)
    return system_content, others
376376

377377

378378
# === PARAMETER CONVERSION UTILITIES ===
379379

380+
380381
def remove_unsupported_params(kwargs: dict[str, Any], unsupported: list[str]) -> dict[str, Any]:
381382
"""Remove unsupported parameters from kwargs."""
382383
cleaned_kwargs = kwargs.copy()
@@ -388,9 +389,9 @@ def remove_unsupported_params(kwargs: dict[str, Any], unsupported: list[str]) ->
388389
def map_parameter_names(kwargs: dict[str, Any], param_mapping: dict[str, str]) -> dict[str, Any]:
    """Rename OpenAI-style parameter keys to the provider's names.

    Keys absent from ``param_mapping`` pass through unchanged; values are
    never modified.
    """
    return {param_mapping.get(key, key): value for key, value in kwargs.items()}

0 commit comments

Comments
 (0)