diff --git a/src/inspect_ai/model/_providers/_mistral_batch.py b/src/inspect_ai/model/_providers/_mistral_batch.py
new file mode 100644
index 0000000000..c6435579b6
--- /dev/null
+++ b/src/inspect_ai/model/_providers/_mistral_batch.py
@@ -0,0 +1,177 @@
+from typing import IO, TypeAlias
+
+import pydantic
+from mistralai import Mistral
+from mistralai.models import BatchJobOut
+from mistralai.models.conversationresponse import ConversationResponse
+from typing_extensions import override
+
+from inspect_ai.model._generate_config import BatchConfig
+from inspect_ai.model._retry import ModelRetryConfig
+
+from .util.batch import Batch, BatchCheckResult, BatchRequest
+from .util.file_batcher import FileBatcher
+
+# Just the output file ID
+CompletedBatchInfo: TypeAlias = str
+
+
+class MistralBatcher(FileBatcher[ConversationResponse, CompletedBatchInfo]):
+    def __init__(
+        self,
+        client: Mistral,
+        config: BatchConfig,
+        retry_config: ModelRetryConfig,
+        model_name: str,
+    ):
+        super().__init__(
+            config=config,
+            retry_config=retry_config,
+            max_batch_request_count=1_000_000,  # 1M ongoing requests per workspace
+            max_batch_size_mb=200,  # Conservative estimate
+        )
+        self._client = client
+        self._model_name = model_name
+
+    # FileBatcher overrides
+
+    @override
+    def _jsonl_line_for_request(
+        self, request: BatchRequest[ConversationResponse], custom_id: str
+    ) -> dict[str, pydantic.JsonValue]:
+        # Request body should already have model set, but filter out non-batch params
+        body = {
+            k: v
+            for k, v in request.request.items()
+            if k not in ("http_headers", "stream")
+        }
+        return {
+            "custom_id": custom_id,
+            "body": body,
+        }
+
+    @override
+    async def _upload_batch_file(
+        self, temp_file: IO[bytes], extra_headers: dict[str, str]
+    ) -> str:
+        file_obj = await self._client.files.upload_async(
+            file={
+                "file_name": temp_file.name,
+                "content": temp_file,
+            },
+            purpose="batch",
+        )
+        return file_obj.id
+
+    @override
+    async def _submit_batch_for_file(
+        self, file_id: str, extra_headers: dict[str, str]
+    ) -> str:
+        batch_job = await self._client.batch.jobs.create_async(
+            input_files=[file_id],
+            model=self._model_name,
+            endpoint="/v1/conversations",
+        )
+        return batch_job.id
+
+    @override
+    async def _download_result_file(self, file_id: str) -> bytes:
+        response = await self._client.files.download_async(file_id=file_id)
+        return response.read()
+
+    @override
+    def _parse_jsonl_line(
+        self, line_data: dict[str, pydantic.JsonValue]
+    ) -> tuple[str, ConversationResponse | Exception]:
+        custom_id = line_data.get("custom_id", "")
+        assert isinstance(custom_id, str), "custom_id must be a string"
+
+        if not custom_id:
+            raise ValueError(
+                f"Unable to find custom_id in batched request result. {line_data}"
+            )
+
+        error = line_data.get("error")
+        if error is not None:
+            return custom_id, RuntimeError(str(error))
+
+        response = line_data.get("response")
+        if not isinstance(response, dict):
+            return custom_id, RuntimeError(f"Invalid response format: {line_data}")
+
+        status_code = response.get("status_code")
+        if status_code != 200:
+            body = response.get("body", {})
+            message = body.get("message", str(body)) if isinstance(body, dict) else body
+            return custom_id, RuntimeError(f"Request failed ({status_code}): {message}")
+
+        body = response.get("body")
+        if not isinstance(body, dict):
+            return custom_id, RuntimeError(f"Invalid response body: {response}")
+
+        return custom_id, ConversationResponse.model_validate(body)
+
+    @override
+    def _uris_from_completion_info(
+        self, completion_info: CompletedBatchInfo
+    ) -> list[str]:
+        return [completion_info]
+
+    # Batcher overrides
+
+    @override
+    async def _check_batch(
+        self, batch: Batch[ConversationResponse]
+    ) -> BatchCheckResult[CompletedBatchInfo]:
+        batch_job: BatchJobOut = await self._client.batch.jobs.get_async(
+            job_id=batch.id
+        )
+
+        # created_at is already a unix timestamp (int)
+        created_at = batch_job.created_at or batch.created_at
+
+        # Map Mistral batch statuses (status is a Literal string type)
+        status = batch_job.status
+        if status in ("QUEUED", "RUNNING"):
+            return BatchCheckResult(
+                completed_count=batch_job.succeeded_requests or 0,
+                failed_count=batch_job.failed_requests or 0,
+                created_at=created_at,
+                completion_info=None,
+            )
+        elif status == "SUCCESS":
+            output_file = batch_job.output_file
+            if not output_file:
+                raise RuntimeError(f"Batch {batch.id} succeeded but has no output file")
+            return BatchCheckResult(
+                completed_count=batch_job.succeeded_requests or len(batch.requests),
+                failed_count=batch_job.failed_requests or 0,
+                created_at=created_at,
+                completion_info=output_file,
+            )
+        elif status in (
+            "FAILED",
+            "TIMEOUT_EXCEEDED",
+            "CANCELLED",
+            "CANCELLATION_REQUESTED",
+        ):
+            # Fail all requests in the batch
+            error_msg = f"Batch {batch.id} ended with status: {status}"
+            await self._resolve_inflight_batch(
+                batch,
+                {req_id: RuntimeError(error_msg) for req_id in batch.requests},
+            )
+            return BatchCheckResult(
+                completed_count=0,
+                failed_count=len(batch.requests),
+                created_at=created_at,
+                completion_info=None,
+            )
+        else:
+            # Unknown status - treat as pending
+            return BatchCheckResult(
+                completed_count=0,
+                failed_count=0,
+                created_at=created_at,
+                completion_info=None,
+            )
diff --git a/src/inspect_ai/model/_providers/mistral.py b/src/inspect_ai/model/_providers/mistral.py
index b70223bc98..7fb9ada015 100644
--- a/src/inspect_ai/model/_providers/mistral.py
+++ b/src/inspect_ai/model/_providers/mistral.py
@@ -63,8 +63,8 @@
     ChatMessage,
     ChatMessageAssistant,
 )
-from .._generate_config import GenerateConfig
-from .._model import ModelAPI
+from .._generate_config import BatchConfig, GenerateConfig, normalized_batch_config
+from .._model import ModelAPI, log_model_retry
 from .._model_call import ModelCall
 from .._model_output import (
     ChatCompletionChoice,
@@ -72,8 +72,12 @@
     ModelUsage,
     StopReason,
 )
+from .._retry import model_retry_config
+from ._mistral_batch import MistralBatcher
 from .mistral_conversation import (
+    completion_choices_from_conversation_response,
     mistral_conversation_generate,
+    mistral_conversation_request,
 )
 from .util import environment_prerequisite_error, model_base_url
 from .util.hooks import HttpxHooks
@@ -149,10 +153,63 @@ def __init__(
             model_args["server_url"] = self.base_url
 
         self.model_args = model_args
+        self._batcher: MistralBatcher | None = None
 
     def is_azure(self) -> bool:
         return self.service == "azure"
 
+    async def _generate_batch(
+        self,
+        client: Mistral,
+        input: list[ChatMessage],
+        tools: list[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+        batch_config: BatchConfig,
+    ) -> ModelOutput:
+        # initialize batcher if needed
+        if not self._batcher:
+            self._batcher = MistralBatcher(
+                client,
+                batch_config,
+                model_retry_config(
+                    self.model_name,
+                    config.max_retries,
+                    config.timeout,
+                    self.should_retry,
+                    lambda ex: None,
+                    log_model_retry,
+                ),
+                self.service_model_name(),
+            )
+
+        # build request
+        request = await mistral_conversation_request(
+            self.service_model_name(), input, tools, tool_choice, config
+        )
+
+        # get response via batcher
+        conv_response = await self._batcher.generate_for_request(request)
+
+        # convert to ModelOutput
+        choices = completion_choices_from_conversation_response(
+            self.service_model_name(), conv_response, tools
+        )
+        return ModelOutput(
+            model=self.service_model_name(),
+            choices=choices,
+            usage=ModelUsage(
+                input_tokens=conv_response.usage.prompt_tokens or 0,
+                output_tokens=(
+                    conv_response.usage.completion_tokens
+                    if conv_response.usage.completion_tokens is not None
+                    else (conv_response.usage.total_tokens or 0)
+                    - (conv_response.usage.prompt_tokens or 0)
+                ),
+                total_tokens=conv_response.usage.total_tokens or 0,
+            ),
+        )
+
     async def generate(
         self,
         input: list[ChatMessage],
@@ -160,6 +217,16 @@ async def generate(
         tool_choice: ToolChoice,
         config: GenerateConfig,
     ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+        # check for batch mode
+        batch_config = normalized_batch_config(config.batch)
+        if batch_config:
+            if self.is_azure():
+                raise ValueError("Batch mode is not supported for Azure Mistral.")
+            if not self.conversation_api:
+                raise ValueError(
+                    "Batch mode requires conversation_api=True (the default)."
+                )
+
         # create client
         with Mistral(api_key=self.api_key, **self.model_args) as client:
             # create time tracker
@@ -167,6 +234,12 @@ async def generate(
 
             # use the conversation api if requested
             if self.conversation_api:
+                # handle batch mode for conversations API
+                if batch_config:
+                    return await self._generate_batch(
+                        client, input, tools, tool_choice, config, batch_config
+                    )
+
                 return await mistral_conversation_generate(
                     client=client,
                     http_hooks=http_hooks,
@@ -271,7 +344,7 @@ def canonical_name(self) -> str:
         return self.service_model_name()
 
     @override
-    def should_retry(self, ex: Exception) -> bool:
+    def should_retry(self, ex: BaseException) -> bool:
         if isinstance(ex, SDKError):
             return is_retryable_http_status(ex.status_code)
         elif httpx_should_retry(ex):
diff --git a/src/inspect_ai/model/_providers/mistral_conversation.py b/src/inspect_ai/model/_providers/mistral_conversation.py
index 0d57c88890..bd4dfb7ebb 100644
--- a/src/inspect_ai/model/_providers/mistral_conversation.py
+++ b/src/inspect_ai/model/_providers/mistral_conversation.py
@@ -64,32 +64,45 @@
 from .util.hooks import HttpxHooks
 
 
-async def mistral_conversation_generate(
-    client: Mistral,
-    http_hooks: HttpxHooks,
+async def mistral_conversation_request(
     model: str,
     input: list[ChatMessage],
     tools: list[ToolInfo],
     tool_choice: ToolChoice,
     config: GenerateConfig,
-    handle_bad_request: Callable[[SDKError], ModelOutput | Exception],
-) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
-    # build request
-    request_id = http_hooks.start_request()
+) -> dict[str, Any]:
+    """Build a Mistral conversations API request dict (for direct calls or batching)."""
     instructions, inputs = await mistral_conversation_inputs(input, config)
     completion_args = mistral_conversation_completion_args(
         config, tool_choice if len(tools) > 0 else None
     )
-    request: dict[str, Any] = dict(
+    return dict(
         model=model,
        instructions=instructions or UNSET,
         inputs=inputs,
         tools=mistral_conversation_tools(tools) if len(tools) > 0 else UNSET,
         completion_args=completion_args,
         store=False,
-        http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
     )
+
+
+async def mistral_conversation_generate(
+    client: Mistral,
+    http_hooks: HttpxHooks,
+    model: str,
+    input: list[ChatMessage],
+    tools: list[ToolInfo],
+    tool_choice: ToolChoice,
+    config: GenerateConfig,
+    handle_bad_request: Callable[[SDKError], ModelOutput | Exception],
+) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
+    # build request
+    request_id = http_hooks.start_request()
+    request = await mistral_conversation_request(
+        model, input, tools, tool_choice, config
+    )
+    request["http_headers"] = {HttpxHooks.REQUEST_ID_HEADER: request_id}
+
     # prepare response for inclusion in model call
     response: dict[str, Any] = {}
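For reference, a minimal usage sketch (not part of the diff) of how this batch path would be exercised end to end. It assumes the `batch` option on `GenerateConfig` that `normalized_batch_config(config.batch)` reads above; the model name and prompt are purely illustrative.

import asyncio

from inspect_ai.model import GenerateConfig, get_model


async def main() -> None:
    # batch=True (assumed GenerateConfig option, per config.batch above) routes
    # generate() calls through MistralBatcher rather than issuing a separate
    # conversations API call per request
    model = get_model(
        "mistral/mistral-large-latest",
        config=GenerateConfig(batch=True),
    )
    output = await model.generate("Say hello in one word.")
    print(output.completion)


asyncio.run(main())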