Draft

WIP #2890
177 changes: 177 additions & 0 deletions src/inspect_ai/model/_providers/_mistral_batch.py
@@ -0,0 +1,177 @@
from typing import IO, TypeAlias

import pydantic
from mistralai import Mistral
from mistralai.models import BatchJobOut
from mistralai.models.conversationresponse import ConversationResponse
from typing_extensions import override

from inspect_ai.model._generate_config import BatchConfig
from inspect_ai.model._retry import ModelRetryConfig

from .util.batch import Batch, BatchCheckResult, BatchRequest
from .util.file_batcher import FileBatcher

# Just the output file ID
CompletedBatchInfo: TypeAlias = str


class MistralBatcher(FileBatcher[ConversationResponse, CompletedBatchInfo]):
    def __init__(
        self,
        client: Mistral,
        config: BatchConfig,
        retry_config: ModelRetryConfig,
        model_name: str,
    ):
        super().__init__(
            config=config,
            retry_config=retry_config,
            max_batch_request_count=1_000_000,  # 1M ongoing requests per workspace
            max_batch_size_mb=200,  # Conservative estimate
        )
        self._client = client
        self._model_name = model_name

    # FileBatcher overrides

    @override
    def _jsonl_line_for_request(
        self, request: BatchRequest[ConversationResponse], custom_id: str
    ) -> dict[str, pydantic.JsonValue]:
        # Request body should already have model set, but filter out non-batch params
        body = {
            k: v
            for k, v in request.request.items()
            if k not in ("http_headers", "stream")
        }
        return {
            "custom_id": custom_id,
            "body": body,
        }

    @override
    async def _upload_batch_file(
        self, temp_file: IO[bytes], extra_headers: dict[str, str]
    ) -> str:
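        # extra_headers is accepted for the FileBatcher interface but not used here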
        file_obj = await self._client.files.upload_async(
            file={
                "file_name": temp_file.name,
                "content": temp_file,
            },
            purpose="batch",
        )
        return file_obj.id

    @override
    async def _submit_batch_for_file(
        self, file_id: str, extra_headers: dict[str, str]
    ) -> str:
        batch_job = await self._client.batch.jobs.create_async(
            input_files=[file_id],
            model=self._model_name,
            endpoint="/v1/conversations",
        )
        return batch_job.id

    @override
    async def _download_result_file(self, file_id: str) -> bytes:
        response = await self._client.files.download_async(file_id=file_id)
        return response.read()

    @override
    def _parse_jsonl_line(
        self, line_data: dict[str, pydantic.JsonValue]
    ) -> tuple[str, ConversationResponse | Exception]:
        custom_id = line_data.get("custom_id", "")
        assert isinstance(custom_id, str), "custom_id must be a string"

        if not custom_id:
            raise ValueError(
                f"Unable to find custom_id in batched request result. {line_data}"
            )

        error = line_data.get("error")
        if error is not None:
            return custom_id, RuntimeError(str(error))

        response = line_data.get("response")
        if not isinstance(response, dict):
            return custom_id, RuntimeError(f"Invalid response format: {line_data}")

        status_code = response.get("status_code")
        if status_code != 200:
            body = response.get("body", {})
            message = body.get("message", str(body)) if isinstance(body, dict) else body
            return custom_id, RuntimeError(f"Request failed ({status_code}): {message}")

        body = response.get("body")
        if not isinstance(body, dict):
            return custom_id, RuntimeError(f"Invalid response body: {response}")

        return custom_id, ConversationResponse.model_validate(body)

    @override
    def _uris_from_completion_info(
        self, completion_info: CompletedBatchInfo
    ) -> list[str]:
        return [completion_info]

    # Batcher overrides

    @override
    async def _check_batch(
        self, batch: Batch[ConversationResponse]
    ) -> BatchCheckResult[CompletedBatchInfo]:
        batch_job: BatchJobOut = await self._client.batch.jobs.get_async(
            job_id=batch.id
        )

        # created_at is already a unix timestamp (int)
        created_at = batch_job.created_at or batch.created_at

        # Map Mistral batch statuses (status is a Literal string type)
        status = batch_job.status
        if status in ("QUEUED", "RUNNING"):
            return BatchCheckResult(
                completed_count=batch_job.succeeded_requests or 0,
                failed_count=batch_job.failed_requests or 0,
                created_at=created_at,
                completion_info=None,
            )
        elif status == "SUCCESS":
            output_file = batch_job.output_file
            if not output_file:
                raise RuntimeError(f"Batch {batch.id} succeeded but has no output file")
            return BatchCheckResult(
                completed_count=batch_job.succeeded_requests or len(batch.requests),
                failed_count=batch_job.failed_requests or 0,
                created_at=created_at,
                completion_info=output_file,
            )
        elif status in (
            "FAILED",
            "TIMEOUT_EXCEEDED",
            "CANCELLED",
            "CANCELLATION_REQUESTED",
        ):
            # Fail all requests in the batch
            error_msg = f"Batch {batch.id} ended with status: {status}"
            await self._resolve_inflight_batch(
                batch,
                {req_id: RuntimeError(error_msg) for req_id in batch.requests},
            )
            return BatchCheckResult(
                completed_count=0,
                failed_count=len(batch.requests),
                created_at=created_at,
                completion_info=None,
            )
        else:
            # Unknown status - treat as pending
            return BatchCheckResult(
                completed_count=0,
                failed_count=0,
                created_at=created_at,
                completion_info=None,
            )
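
For reference, a minimal sketch of the JSONL line shapes this batcher writes and later parses, inferred from _jsonl_line_for_request and _parse_jsonl_line above; the custom_id, model name, error text, and body fields are illustrative placeholders, not captured from a real batch:

# request line written to the uploaded batch file (one JSON object per request)
request_line = {
    "custom_id": "req-0",  # assigned by FileBatcher
    "body": {
        "model": "mistral-large-latest",  # illustrative model name
        "inputs": "...",  # conversation inputs built by mistral_conversation_request
        "store": False,
    },
}

# successful result line: status 200 with a body that validates as ConversationResponse
success_line = {
    "custom_id": "req-0",
    "response": {"status_code": 200, "body": {"...": "ConversationResponse fields"}},
}

# failed result line: surfaced to the caller as a RuntimeError
error_line = {"custom_id": "req-1", "error": "rate limit exceeded"}
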
79 changes: 76 additions & 3 deletions src/inspect_ai/model/_providers/mistral.py
@@ -63,17 +63,21 @@
    ChatMessage,
    ChatMessageAssistant,
)
from .._generate_config import GenerateConfig
from .._model import ModelAPI
from .._generate_config import BatchConfig, GenerateConfig, normalized_batch_config
from .._model import ModelAPI, log_model_retry
from .._model_call import ModelCall
from .._model_output import (
    ChatCompletionChoice,
    ModelOutput,
    ModelUsage,
    StopReason,
)
from .._retry import model_retry_config
from ._mistral_batch import MistralBatcher
from .mistral_conversation import (
    completion_choices_from_conversation_response,
    mistral_conversation_generate,
    mistral_conversation_request,
)
from .util import environment_prerequisite_error, model_base_url
from .util.hooks import HttpxHooks
@@ -149,24 +153,93 @@ def __init__(
        model_args["server_url"] = self.base_url

        self.model_args = model_args
        self._batcher: MistralBatcher | None = None

    def is_azure(self) -> bool:
        return self.service == "azure"

    async def _generate_batch(
        self,
        client: Mistral,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
        batch_config: BatchConfig,
    ) -> ModelOutput:
        # initialize batcher if needed
        if not self._batcher:
            self._batcher = MistralBatcher(
                client,
                batch_config,
                model_retry_config(
                    self.model_name,
                    config.max_retries,
                    config.timeout,
                    self.should_retry,
                    lambda ex: None,
                    log_model_retry,
                ),
                self.service_model_name(),
            )

        # build request
        request = await mistral_conversation_request(
            self.service_model_name(), input, tools, tool_choice, config
        )

        # get response via batcher
        conv_response = await self._batcher.generate_for_request(request)

        # convert to ModelOutput
        choices = completion_choices_from_conversation_response(
            self.service_model_name(), conv_response, tools
        )
        return ModelOutput(
            model=self.service_model_name(),
            choices=choices,
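            # completion_tokens may be unreported; fall back to total - prompt tokens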
            usage=ModelUsage(
                input_tokens=conv_response.usage.prompt_tokens or 0,
                output_tokens=(
                    conv_response.usage.completion_tokens
                    if conv_response.usage.completion_tokens is not None
                    else (conv_response.usage.total_tokens or 0)
                    - (conv_response.usage.prompt_tokens or 0)
                ),
                total_tokens=conv_response.usage.total_tokens or 0,
            ),
        )

    async def generate(
        self,
        input: list[ChatMessage],
        tools: list[ToolInfo],
        tool_choice: ToolChoice,
        config: GenerateConfig,
    ) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
        # check for batch mode
        batch_config = normalized_batch_config(config.batch)
        if batch_config:
            if self.is_azure():
                raise ValueError("Batch mode is not supported for Azure Mistral.")
            if not self.conversation_api:
                raise ValueError(
                    "Batch mode requires conversation_api=True (the default)."
                )

        # create client
        with Mistral(api_key=self.api_key, **self.model_args) as client:
            # create time tracker
            http_hooks = HttpxHooks(client.sdk_configuration.async_client)

            # use the conversation api if requested
            if self.conversation_api:
                # handle batch mode for conversations API
                if batch_config:
                    return await self._generate_batch(
                        client, input, tools, tool_choice, config, batch_config
                    )

                return await mistral_conversation_generate(
                    client=client,
                    http_hooks=http_hooks,
@@ -271,7 +344,7 @@ def canonical_name(self) -> str:
        return self.service_model_name()

    @override
    def should_retry(self, ex: Exception) -> bool:
    def should_retry(self, ex: BaseException) -> bool:
        if isinstance(ex, SDKError):
            return is_retryable_http_status(ex.status_code)
        elif httpx_should_retry(ex):
31 changes: 22 additions & 9 deletions src/inspect_ai/model/_providers/mistral_conversation.py
@@ -64,32 +64,45 @@
from .util.hooks import HttpxHooks


async def mistral_conversation_generate(
    client: Mistral,
    http_hooks: HttpxHooks,
async def mistral_conversation_request(
    model: str,
    input: list[ChatMessage],
    tools: list[ToolInfo],
    tool_choice: ToolChoice,
    config: GenerateConfig,
    handle_bad_request: Callable[[SDKError], ModelOutput | Exception],
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
    # build request
    request_id = http_hooks.start_request()
) -> dict[str, Any]:
    """Build a Mistral conversations API request dict (for direct calls or batching)."""
    instructions, inputs = await mistral_conversation_inputs(input, config)
    completion_args = mistral_conversation_completion_args(
        config, tool_choice if len(tools) > 0 else None
    )
    request: dict[str, Any] = dict(
    return dict(
        model=model,
        instructions=instructions or UNSET,
        inputs=inputs,
        tools=mistral_conversation_tools(tools) if len(tools) > 0 else UNSET,
        completion_args=completion_args,
        store=False,
        http_headers={HttpxHooks.REQUEST_ID_HEADER: request_id},
    )


async def mistral_conversation_generate(
    client: Mistral,
    http_hooks: HttpxHooks,
    model: str,
    input: list[ChatMessage],
    tools: list[ToolInfo],
    tool_choice: ToolChoice,
    config: GenerateConfig,
    handle_bad_request: Callable[[SDKError], ModelOutput | Exception],
) -> ModelOutput | tuple[ModelOutput | Exception, ModelCall]:
    # build request
    request_id = http_hooks.start_request()
    request = await mistral_conversation_request(
        model, input, tools, tool_choice, config
    )
    request["http_headers"] = {HttpxHooks.REQUEST_ID_HEADER: request_id}

    # prepare response for inclusion in model call
    response: dict[str, Any] = {}

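For context, a hedged sketch of how batch mode could be enabled for this provider from user code. It assumes the existing GenerateConfig.batch field that normalized_batch_config reads above; whether batch=True is an accepted value (versus a batch size or a BatchConfig) is an assumption, and the model name is illustrative:

from inspect_ai.model import GenerateConfig, get_model

# route conversations requests through MistralBatcher instead of issuing them one at a time
model = get_model(
    "mistral/mistral-large-latest",  # illustrative model name
    config=GenerateConfig(batch=True),  # assumed form accepted by normalized_batch_config
)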