2 changes: 1 addition & 1 deletion marimo/_server/ai/config.py
@@ -211,7 +211,7 @@ def for_bedrock(cls, config: AiConfig) -> AnyProviderConfig:
ai_config = _get_ai_config(config, "bedrock")
key = _get_key(ai_config, "Bedrock")
return cls(
base_url=_get_base_url(ai_config),
base_url=_get_base_url(ai_config, "Bedrock"),
api_key=key,
tools=_get_tools(config.get("mode", "manual")),
)
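This one-line fix threads the provider name into _get_base_url, so the AWS region configured for Bedrock actually reaches the provider config instead of silently resolving to None (see the test updates at the bottom of this diff). A minimal sketch of the observable behavior; the "region_name" and "profile_name" keys are assumed from the test fixtures, and only the asserted values come from this PR:

    # Hedged sketch: config keys assumed, asserted values taken from the tests below.
    config: AiConfig = {
        "bedrock": {
            "region_name": "us-east-1",
            "profile_name": "test-profile",
        }
    }
    provider_config = AnyProviderConfig.for_bedrock(config)
    assert provider_config.api_key == "profile:test-profile"
    assert provider_config.base_url == "us-east-1"  # was None before this fix
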
212 changes: 63 additions & 149 deletions marimo/_server/ai/providers.py
@@ -49,12 +49,6 @@
if TYPE_CHECKING:
# Used for Bedrock, unified interface for all models
from anthropic.types.beta import BetaThinkingConfigParam
from litellm import ( # type: ignore[attr-defined]
CustomStreamWrapper as LitellmStream,
)
from litellm.types.utils import (
ModelResponseStream as LitellmStreamResponse,
)
from openai import ( # type: ignore[import-not-found]
AsyncOpenAI,
AsyncStream as OpenAiStream,
@@ -64,10 +58,16 @@
)
from pydantic_ai import Agent, DeferredToolRequests, FunctionToolset
from pydantic_ai.messages import ThinkingPart
from pydantic_ai.models import Model
from pydantic_ai.models.bedrock import BedrockConverseModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.providers import Provider
from pydantic_ai.providers.anthropic import (
AnthropicProvider as PydanticAnthropic,
)
from pydantic_ai.providers.bedrock import (
BedrockProvider as PydanticBedrock,
)
from pydantic_ai.providers.google import GoogleProvider as PydanticGoogle
from pydantic_ai.ui.vercel_ai import VercelAIAdapter
from pydantic_ai.ui.vercel_ai.request_types import UIMessage, UIMessagePart
@@ -148,10 +148,26 @@ def create_provider(self, config: AnyProviderConfig) -> ProviderT:
"""Create a provider for the given config."""

@abstractmethod
def create_model(self, max_tokens: int) -> Model:
"""Create a Pydantic AI model for the given max tokens."""

def create_agent(
self, max_tokens: int, tools: list[ToolDefinition], system_prompt: str
self,
max_tokens: int,
tools: list[ToolDefinition],
system_prompt: str,
) -> Agent[None, DeferredToolRequests | str]:
"""Create a Pydantic AI agent"""
from pydantic_ai import Agent

model = self.create_model(max_tokens)
toolset, output_type = self._get_toolsets_and_output_type(tools)
return Agent(
model,
toolsets=[toolset] if tools else None,
instructions=system_prompt,
output_type=output_type,
)

def get_vercel_adapter(self) -> type[VercelAIAdapter[Any, Any]]:
"""Return the Vercel AI adapter for the given provider."""
@@ -290,27 +306,17 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticGoogle:
provider = PydanticGoogle()
return provider

def create_agent(
self, max_tokens: int, tools: list[ToolDefinition], system_prompt: str
) -> Agent[None, DeferredToolRequests | str]:
from pydantic_ai import Agent
def create_model(self, max_tokens: int) -> GoogleModel:
from pydantic_ai.models.google import GoogleModel, GoogleModelSettings

toolset, output_type = self._get_toolsets_and_output_type(tools)

return Agent(
GoogleModel(
model_name=self.model,
provider=self.provider,
settings=GoogleModelSettings(
max_tokens=max_tokens,
# Works on non-thinking models too
google_thinking_config={"include_thoughts": True},
),
return GoogleModel(
model_name=self.model,
provider=self.provider,
settings=GoogleModelSettings(
max_tokens=max_tokens,
# Works on non-thinking models too
google_thinking_config={"include_thoughts": True},
),
toolsets=[toolset] if tools else None,
instructions=system_prompt,
output_type=output_type,
)


@@ -953,16 +959,12 @@ def create_provider(self, config: AnyProviderConfig) -> PydanticAnthropic:

return PydanticAnthropic(api_key=config.api_key)

def create_agent(
self, max_tokens: int, tools: list[ToolDefinition], system_prompt: str
) -> Agent[None, DeferredToolRequests | str]:
from pydantic_ai import Agent
def create_model(self, max_tokens: int) -> Model:
from pydantic_ai.models.anthropic import (
AnthropicModel,
AnthropicModelSettings,
)

toolset, output_type = self._get_toolsets_and_output_type(tools)
is_thinking_model = self.is_extended_thinking_model(self.model)
thinking_config: BetaThinkingConfigParam = {"type": "disabled"}
if is_thinking_model:
@@ -971,19 +973,14 @@ def create_agent(
"budget_tokens": self.DEFAULT_EXTENDED_THINKING_BUDGET_TOKENS,
}

return Agent(
AnthropicModel(
model_name=self.model,
provider=self.provider,
settings=AnthropicModelSettings(
max_tokens=max_tokens,
temperature=self.get_temperature(),
anthropic_thinking=thinking_config,
),
return AnthropicModel(
model_name=self.model,
provider=self.provider,
settings=AnthropicModelSettings(
max_tokens=max_tokens,
temperature=self.get_temperature(),
anthropic_thinking=thinking_config,
),
toolsets=[toolset] if tools else None,
instructions=system_prompt,
output_type=output_type,
)

def is_extended_thinking_model(self, model: str) -> bool:
@@ -1088,12 +1085,7 @@ def build_event_stream(self) -> AnthropicVercelAIEventStream:
return AnthropicVercelAIAdapter


class BedrockProvider(
CompletionProvider[
"LitellmStreamResponse",
"LitellmStream",
]
):
class BedrockProvider(PydanticProvider["PydanticBedrock"]):
def setup_credentials(self, config: AnyProviderConfig) -> None:
# Use profile name if provided, otherwise use API key
try:
@@ -1113,109 +1105,29 @@ def setup_credentials(self, config: AnyProviderConfig) -> None:
detail="Error setting up AWS credentials",
) from e

async def stream_completion(
self,
messages: list[ChatMessage],
system_prompt: str,
max_tokens: int,
additional_tools: list[ToolDefinition],
) -> LitellmStream:
DependencyManager.litellm.require(why="for AI assistance with Bedrock")
DependencyManager.boto3.require(why="for AI assistance with Bedrock")
from litellm import acompletion as litellm_completion

self.setup_credentials(self.config)
tools = self.config.tools

config = {
"model": self.model,
"messages": cast(
Any,
convert_to_openai_messages(
[ChatMessage(role="system", content=system_prompt)]
+ messages
),
),
"max_completion_tokens": max_tokens,
"stream": True,
"timeout": TIMEOUT,
}
if tools:
all_tools = tools + additional_tools
config["tools"] = convert_to_openai_tools(all_tools)

return await litellm_completion(**config)

def extract_content(
self,
response: LitellmStreamResponse,
tool_call_ids: Optional[list[str]] = None,
) -> Optional[ExtractedContentList]:
tool_call_ids = tool_call_ids or []
if (
hasattr(response, "choices")
and response.choices
and response.choices[0].delta
):
delta = response.choices[0].delta

# Text content
content = delta.content
if content:
return [(str(content), "text")]

# Tool call: LiteLLM follows OpenAI format for tool calls

if hasattr(delta, "tool_calls") and delta.tool_calls:
tool_content: ExtractedContentList = []

for tool_call in delta.tool_calls:
tool_index: int = tool_call.index

# Start of tool call
# id is only present for the first tool call chunk
if hasattr(tool_call, "id") and tool_call.id:
tool_info = {
"toolCallId": tool_call.id,
"toolName": tool_call.function.name,
}
tool_content.append((tool_info, "tool_call_start"))

# Delta of tool call
# arguments is only present second chunk onwards
if (
hasattr(tool_call, "function")
and tool_call.function
and hasattr(tool_call.function, "arguments")
and tool_call.function.arguments
and tool_index < len(tool_call_ids)
and tool_call_ids[tool_index]
):
tool_delta = {
"toolCallId": tool_call_ids[tool_index],
"inputTextDelta": tool_call.function.arguments,
}
tool_content.append((tool_delta, "tool_call_delta"))
def create_provider(self, config: AnyProviderConfig) -> PydanticBedrock:
from pydantic_ai.providers.bedrock import (
BedrockProvider as PydanticBedrock,
)

# return the tool content
return tool_content
self.setup_credentials(config)
# For bedrock, the config sets the region name as the base_url
return PydanticBedrock(region_name=config.base_url)

return None
def create_model(self, max_tokens: int) -> BedrockConverseModel:
from pydantic_ai.models.bedrock import (
BedrockConverseModel,
BedrockModelSettings,
)

def get_finish_reason(
self, response: LitellmStreamResponse
) -> Optional[FinishReason]:
if (
hasattr(response, "choices")
and response.choices
and response.choices[0].finish_reason
):
return (
"tool_calls"
if response.choices[0].finish_reason == "tool_calls"
else "stop"
)
return None
return BedrockConverseModel(
model_name=self.model,
provider=self.provider,
settings=BedrockModelSettings(
max_tokens=max_tokens,
# TODO: Add reasoning support
[Author's inline review comment: "we didn't have reasoning before, so I did not add new logic to support this."]

),
)
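
Net effect: the hand-rolled litellm streaming, tool-call parsing, and finish-reason mapping removed above are replaced by pydantic-ai's Bedrock integration, while credential setup stays on the provider. A hedged sketch of how the pieces now compose (model id and region are placeholders):

    # Hedged sketch: constructor shape follows get_completion_provider below.
    bedrock = BedrockProvider(
        "anthropic.claude-3-5-sonnet-20240620-v1:0",  # placeholder Bedrock model id
        config,  # config.base_url carries the AWS region, e.g. "us-east-1"
        [DependencyManager.boto3],
    )
    model = bedrock.create_model(max_tokens=8192)  # BedrockConverseModel
    agent = bedrock.create_agent(
        max_tokens=8192,
        tools=[],
        system_prompt="You are a helpful assistant.",  # placeholder prompt
    )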


def get_completion_provider(
@@ -1232,7 +1144,9 @@ def get_completion_provider(
model_id.model, config, [DependencyManager.google_ai]
)
elif model_id.provider == "bedrock":
return BedrockProvider(model_id.model, config)
return BedrockProvider(
model_id.model, config, [DependencyManager.boto3]
)
elif model_id.provider == "azure":
return AzureOpenAIProvider(model_id.model, config)
elif model_id.provider == "openrouter":
13 changes: 7 additions & 6 deletions pixi.lock

(Generated lockfile; diff not rendered.)

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -183,7 +183,7 @@ dependencies = [
"matplotlib>=3.8.0",
"sqlglot[rs]>=26.2.0",
"sqlalchemy>=2.0.40",
"pydantic-ai-slim[google,anthropic]>=1.39.0",
"pydantic-ai-slim[google,anthropic,bedrock]>=1.39.0",
"openai>=1.55.3",
"loro>=1.5.0",
"pandas-stubs>=1.5.3.230321",
@@ -260,10 +260,11 @@ extra-dependencies = [
"ipython~=8.12.3",
# testing gen ai
"openai>=1.55.3",
"pydantic-ai-slim[google,anthropic]>=1.39.0",
"pydantic-ai-slim[google,anthropic,bedrock]>=1.39.0",
# - google-auth uses cachetools, and cachetools<5.0.0 uses collections.MutableMapping (removed in Python 3.10)
"cachetools>=5.0.0",
"boto3>=1.38.46",
# for bedrock ui chat
"litellm>=1.70.0",
# exporting as ipynb
"nbformat>=5.10.4",
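Both the runtime and test dependency groups gain the bedrock extra, which is what ships the pydantic_ai.models.bedrock and pydantic_ai.providers.bedrock modules used above. To reproduce the dependency change locally (pins taken from this diff):

    pip install "pydantic-ai-slim[google,anthropic,bedrock]>=1.39.0" "boto3>=1.38.46"
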
6 changes: 2 additions & 4 deletions tests/_server/ai/test_ai_config.py
@@ -359,8 +359,7 @@ def test_for_bedrock_with_profile(self):
provider_config = AnyProviderConfig.for_bedrock(config)

assert provider_config.api_key == "profile:test-profile"
# Note: base_url is None because _get_base_url doesn't get "Bedrock" name parameter
assert provider_config.base_url is None
assert provider_config.base_url == "us-east-1"

def test_for_bedrock_with_credentials(self):
"""Test Bedrock configuration with AWS credentials."""
@@ -375,8 +374,7 @@ def test_for_bedrock_with_credentials(self):
provider_config = AnyProviderConfig.for_bedrock(config)

assert provider_config.api_key == "test-access-key:test-secret-key"
# Note: base_url is None because _get_base_url doesn't get "Bedrock" name parameter
assert provider_config.base_url is None
assert provider_config.base_url == "us-west-2"

def test_for_model_openai(self) -> None:
"""Test for_model with OpenAI model."""
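The updated assertions pin the new behavior: base_url now carries the configured AWS region rather than None. To exercise just these cases locally (assuming marimo's usual pytest setup):

    pytest tests/_server/ai/test_ai_config.py -k bedrock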