Skip to content
121 changes: 119 additions & 2 deletions dapr_agents/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import sys
import uuid
from datetime import datetime, timezone
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence

from dapr.ext.workflow import (
DaprWorkflowClient,
Expand All @@ -16,7 +16,7 @@
)
from dapr.ext.workflow.workflow_state import WorkflowState
from durabletask import task as dtask
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator

from dapr_agents.agents.base import ChatClientBase
from dapr_agents.llm.utils.defaults import get_default_llm
Expand Down Expand Up @@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
default=300,
description="Default timeout duration in seconds for workflow tasks.",
)
grpc_max_send_message_length: Optional[int] = Field(
default=None,
description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
)
grpc_max_receive_message_length: Optional[int] = Field(
default=None,
description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
)

# Initialized in model_post_init
wf_runtime: Optional[WorkflowRuntime] = Field(
Expand All @@ -68,10 +76,30 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):

model_config = ConfigDict(arbitrary_types_allowed=True)

@model_validator(mode="before")
def validate_grpc_chanell_options(cls, values: Any):
if not isinstance(values, dict):
return values

if values.get("grpc_max_send_message_length") is not None:
if values["grpc_max_send_message_length"] < 0:
raise ValueError("grpc_max_send_message_length must be greater than 0")

if values.get("grpc_max_receive_message_length") is not None:
if values["grpc_max_receive_message_length"] < 0:
raise ValueError(
"grpc_max_receive_message_length must be greater than 0"
)

return values

def model_post_init(self, __context: Any) -> None:
"""
Initialize the Dapr workflow runtime and register tasks & workflows.
"""
if self.grpc_max_send_message_length or self.grpc_max_receive_message_length:
self._configure_grpc_channel_options()

# Initialize LLM first
if self.llm is None:
self.llm = get_default_llm()
Expand All @@ -92,6 +120,95 @@ def model_post_init(self, __context: Any) -> None:

super().model_post_init(__context)

def _configure_grpc_channel_options(self) -> None:
"""
Configure gRPC channel options before workflow runtime initialization.
This patches the durabletask internal channel factory to support custom message size limits.

This is particularly useful for AI-powered workflows that may need to handle large payloads
such as images, which can exceed the default 4MB gRPC message size limit.
"""
try:
import grpc
from durabletask.internal import shared

# Create custom options list
options = []
if self.grpc_max_send_message_length:
options.append(
("grpc.max_send_message_length", self.grpc_max_send_message_length)
)
logger.debug(
f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)"
)
if self.grpc_max_receive_message_length:
options.append(
(
"grpc.max_receive_message_length",
self.grpc_max_receive_message_length,
)
)
logger.debug(
f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)"
)

# Patch the function to include our custom options
def get_grpc_channel_with_options(
host_address: Optional[str],
secure_channel: bool = False,
interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None,
):
# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at
# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19
# but with my option overrides applied above.
if host_address is None:
host_address = shared.get_default_host_address()

for protocol in getattr(shared, "SECURE_PROTOCOLS", []):
if host_address.lower().startswith(protocol):
secure_channel = True
# remove the protocol from the host name
host_address = host_address[len(protocol) :]
break

for protocol in getattr(shared, "INSECURE_PROTOCOLS", []):
if host_address.lower().startswith(protocol):
secure_channel = False
# remove the protocol from the host name
host_address = host_address[len(protocol) :]
break

# Create the base channel
if secure_channel:
credentials = grpc.ssl_channel_credentials()
channel = grpc.secure_channel(
host_address, credentials, options=options
)
else:
channel = grpc.insecure_channel(host_address, options=options)

# Apply interceptors ONLY if they exist
if interceptors:
channel = grpc.intercept_channel(channel, *interceptors)

return channel

# Replace the function
shared.get_grpc_channel = get_grpc_channel_with_options

logger.debug(
"Successfully patched durabletask gRPC channel factory with custom options"
)

except ImportError as e:
logger.error(
f"Failed to import required modules for gRPC configuration: {e}"
)
raise
except Exception as e:
logger.error(f"Failed to configure gRPC channel options: {e}")
raise

def graceful_shutdown(self) -> None:
"""
Perform graceful shutdown operations for the WorkflowApp.
Expand Down
212 changes: 212 additions & 0 deletions tests/workflow/test_grpc_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,212 @@
"""Tests for gRPC configuration in WorkflowApp."""
import pytest
from unittest.mock import MagicMock, patch, call
import types
from dapr_agents.workflow.base import WorkflowApp


@pytest.fixture
def mock_workflow_dependencies():
"""Mock all the dependencies needed for WorkflowApp initialization."""
with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, patch(
"dapr_agents.workflow.base.DaprWorkflowClient"
) as mock_client, patch(
"dapr_agents.workflow.base.get_default_llm"
) as mock_llm, patch.object(
WorkflowApp, "start_runtime"
) as mock_start, patch.object(
WorkflowApp, "setup_signal_handlers"
) as mock_handlers:
mock_runtime_instance = MagicMock()
mock_runtime.return_value = mock_runtime_instance

mock_client_instance = MagicMock()
mock_client.return_value = mock_client_instance

mock_llm_instance = MagicMock()
mock_llm.return_value = mock_llm_instance

yield {
"runtime": mock_runtime,
"runtime_instance": mock_runtime_instance,
"client": mock_client,
"client_instance": mock_client_instance,
"llm": mock_llm,
"llm_instance": mock_llm_instance,
"start_runtime": mock_start,
"signal_handlers": mock_handlers,
}


def test_workflow_app_without_grpc_config(mock_workflow_dependencies):
"""Test that WorkflowApp initializes without gRPC configuration."""
# Create WorkflowApp without gRPC config
app = WorkflowApp()

# Verify the app was created
assert app is not None
assert app.grpc_max_send_message_length is None
assert app.grpc_max_receive_message_length is None

# Verify runtime and client were initialized
assert app.wf_runtime is not None
assert app.wf_client is not None


def test_workflow_app_with_grpc_config(mock_workflow_dependencies):
"""Test that WorkflowApp initializes with gRPC configuration."""
# Mock the grpc module and durabletask shared module
mock_grpc = MagicMock()
mock_shared = MagicMock()
mock_channel = MagicMock()

# Set up the mock channel
mock_grpc.insecure_channel.return_value = mock_channel
mock_shared.get_grpc_channel = MagicMock()

with patch.dict(
"sys.modules",
{
"grpc": mock_grpc,
"durabletask.internal.shared": mock_shared,
},
):
# Create WorkflowApp with gRPC config (16MB)
app = WorkflowApp(
grpc_max_send_message_length=16 * 1024 * 1024, # 16MB
grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB
)

# Verify the configuration was set
assert app.grpc_max_send_message_length == 16 * 1024 * 1024
assert app.grpc_max_receive_message_length == 16 * 1024 * 1024

# Verify runtime and client were initialized
assert app.wf_runtime is not None
assert app.wf_client is not None


def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies):
"""Test that _configure_grpc_channel_options is called when gRPC config is provided."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
# Create WorkflowApp with gRPC config
WorkflowApp(
grpc_max_send_message_length=8 * 1024 * 1024, # 8MB
)

# Verify the configuration method was called
mock_configure.assert_called_once()


def test_configure_grpc_channel_options_not_called_without_config(
mock_workflow_dependencies,
):
"""Test that _configure_grpc_channel_options is not called without gRPC config."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
# Create WorkflowApp without gRPC config
WorkflowApp()

# Verify the configuration method was NOT called
mock_configure.assert_not_called()


def test_grpc_channel_patching():
"""Test that the gRPC channel factory is properly patched with custom options."""
# Mock the grpc module and durabletask shared module
mock_grpc = MagicMock()
mock_shared = MagicMock()
mock_channel = MagicMock()

# Set up the mock channel
mock_grpc.insecure_channel.return_value = mock_channel

# Keep original reference
def original_get_grpc_channel(*_, **__):
return "original"

mock_shared.get_grpc_channel = original_get_grpc_channel

# Create dummy package/module structure so 'from durabletask.internal import shared' works
durabletask_module = types.ModuleType("durabletask")
internal_module = types.ModuleType("durabletask.internal")
setattr(durabletask_module, "internal", internal_module)
setattr(internal_module, "shared", mock_shared)

with patch.dict(
"sys.modules",
{
"grpc": mock_grpc,
"durabletask": durabletask_module,
"durabletask.internal": internal_module,
"durabletask.internal.shared": mock_shared,
},
), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch(
"dapr_agents.workflow.base.DaprWorkflowClient"
), patch("dapr_agents.workflow.base.get_default_llm"), patch.object(
WorkflowApp, "start_runtime"
), patch.object(WorkflowApp, "setup_signal_handlers"):
# Create WorkflowApp with gRPC config
max_send = 10 * 1024 * 1024 # 10MB
max_recv = 12 * 1024 * 1024 # 12MB

WorkflowApp(
grpc_max_send_message_length=max_send,
grpc_max_receive_message_length=max_recv,
)

# Confirm get_grpc_channel was overridden
assert callable(mock_shared.get_grpc_channel)
assert mock_shared.get_grpc_channel is not original_get_grpc_channel
assert (
getattr(mock_shared.get_grpc_channel, "__name__", "")
== "get_grpc_channel_with_options"
)

# Call the patched function
test_address = "localhost:50001"
mock_shared.get_grpc_channel(test_address)

# Verify insecure_channel was called with correct options
mock_grpc.insecure_channel.assert_called_once()
call_args = mock_grpc.insecure_channel.call_args

# Check that the address was passed
assert call_args[0][0] == test_address

# Check that options were passed
assert "options" in call_args.kwargs
options = call_args.kwargs["options"]

# Verify options contain our custom message size limits
assert ("grpc.max_send_message_length", max_send) in options
assert ("grpc.max_receive_message_length", max_recv) in options


def test_grpc_config_with_only_send_limit(mock_workflow_dependencies):
"""Test gRPC configuration with only send limit set."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
app = WorkflowApp(
grpc_max_send_message_length=20 * 1024 * 1024, # 20MB
)

# Verify configuration was called
mock_configure.assert_called_once()

# Verify only send limit was set
assert app.grpc_max_send_message_length == 20 * 1024 * 1024
assert app.grpc_max_receive_message_length is None


def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies):
"""Test gRPC configuration with only receive limit set."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
app = WorkflowApp(
grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB
)

# Verify configuration was called
mock_configure.assert_called_once()

# Verify only receive limit was set
assert app.grpc_max_send_message_length is None
assert app.grpc_max_receive_message_length == 24 * 1024 * 1024