Skip to content
64 changes: 64 additions & 0 deletions dapr_agents/workflow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
default=300,
description="Default timeout duration in seconds for workflow tasks.",
)
grpc_max_send_message_length: Optional[int] = Field(
default=None,
description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
)
grpc_max_receive_message_length: Optional[int] = Field(
default=None,
description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
)

# Initialized in model_post_init
wf_runtime: Optional[WorkflowRuntime] = Field(
Expand All @@ -72,6 +80,9 @@ def model_post_init(self, __context: Any) -> None:
"""
Initialize the Dapr workflow runtime and register tasks & workflows.
"""
if self.grpc_max_send_message_length or self.grpc_max_receive_message_length:
self._configure_grpc_channel_options()

# Initialize LLM first
if self.llm is None:
self.llm = get_default_llm()
Expand All @@ -92,6 +103,59 @@ def model_post_init(self, __context: Any) -> None:

super().model_post_init(__context)

def _configure_grpc_channel_options(self) -> None:
"""
Configure gRPC channel options before workflow runtime initialization.
This patches the durabletask internal channel factory to support custom message size limits.

This is particularly useful for AI-powered workflows that may need to handle large payloads
such as images, which can exceed the default 4MB gRPC message size limit.
"""
try:
import grpc
from durabletask.internal import shared

# Create custom options list
options = []
if self.grpc_max_send_message_length:
options.append(
("grpc.max_send_message_length", self.grpc_max_send_message_length)
)
logger.debug(
f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)"
)
if self.grpc_max_receive_message_length:
options.append(
(
"grpc.max_receive_message_length",
self.grpc_max_receive_message_length,
)
)
logger.debug(
f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)"
)

# Patch the function to include our custom options
def get_grpc_channel_with_options(address: str):
"""Custom gRPC channel factory with configured message size limits."""
return grpc.insecure_channel(address, options=options)

# Replace the function
shared.get_grpc_channel = get_grpc_channel_with_options

logger.debug(
"Successfully patched durabletask gRPC channel factory with custom options"
)

except ImportError as e:
logger.error(
f"Failed to import required modules for gRPC configuration: {e}"
)
raise
except Exception as e:
logger.error(f"Failed to configure gRPC channel options: {e}")
raise

def graceful_shutdown(self) -> None:
"""
Perform graceful shutdown operations for the WorkflowApp.
Expand Down
197 changes: 197 additions & 0 deletions tests/workflow/test_grpc_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
"""Tests for gRPC configuration in WorkflowApp."""
import pytest
from unittest.mock import MagicMock, patch, call
from dapr_agents.workflow.base import WorkflowApp


@pytest.fixture
def mock_workflow_dependencies():
"""Mock all the dependencies needed for WorkflowApp initialization."""
with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, patch(
"dapr_agents.workflow.base.DaprWorkflowClient"
) as mock_client, patch(
"dapr_agents.workflow.base.get_default_llm"
) as mock_llm, patch.object(
WorkflowApp, "start_runtime"
) as mock_start, patch.object(
WorkflowApp, "setup_signal_handlers"
) as mock_handlers:
mock_runtime_instance = MagicMock()
mock_runtime.return_value = mock_runtime_instance

mock_client_instance = MagicMock()
mock_client.return_value = mock_client_instance

mock_llm_instance = MagicMock()
mock_llm.return_value = mock_llm_instance

yield {
"runtime": mock_runtime,
"runtime_instance": mock_runtime_instance,
"client": mock_client,
"client_instance": mock_client_instance,
"llm": mock_llm,
"llm_instance": mock_llm_instance,
"start_runtime": mock_start,
"signal_handlers": mock_handlers,
}


def test_workflow_app_without_grpc_config(mock_workflow_dependencies):
"""Test that WorkflowApp initializes without gRPC configuration."""
# Create WorkflowApp without gRPC config
app = WorkflowApp()

# Verify the app was created
assert app is not None
assert app.grpc_max_send_message_length is None
assert app.grpc_max_receive_message_length is None

# Verify runtime and client were initialized
assert app.wf_runtime is not None
assert app.wf_client is not None


def test_workflow_app_with_grpc_config(mock_workflow_dependencies):
"""Test that WorkflowApp initializes with gRPC configuration."""
# Mock the grpc module and durabletask shared module
mock_grpc = MagicMock()
mock_shared = MagicMock()
mock_channel = MagicMock()

# Set up the mock channel
mock_grpc.insecure_channel.return_value = mock_channel
mock_shared.get_grpc_channel = MagicMock()

with patch.dict(
"sys.modules",
{
"grpc": mock_grpc,
"durabletask.internal.shared": mock_shared,
},
):
# Create WorkflowApp with gRPC config (16MB)
app = WorkflowApp(
grpc_max_send_message_length=16 * 1024 * 1024, # 16MB
grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB
)

# Verify the configuration was set
assert app.grpc_max_send_message_length == 16 * 1024 * 1024
assert app.grpc_max_receive_message_length == 16 * 1024 * 1024

# Verify runtime and client were initialized
assert app.wf_runtime is not None
assert app.wf_client is not None


def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies):
"""Test that _configure_grpc_channel_options is called when gRPC config is provided."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
# Create WorkflowApp with gRPC config
WorkflowApp(
grpc_max_send_message_length=8 * 1024 * 1024, # 8MB
)

# Verify the configuration method was called
mock_configure.assert_called_once()


def test_configure_grpc_channel_options_not_called_without_config(
mock_workflow_dependencies,
):
"""Test that _configure_grpc_channel_options is not called without gRPC config."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
# Create WorkflowApp without gRPC config
WorkflowApp()

# Verify the configuration method was NOT called
mock_configure.assert_not_called()


def test_grpc_channel_patching():
"""Test that the gRPC channel factory is properly patched with custom options."""
# Mock the grpc module and durabletask shared module
mock_grpc = MagicMock()
mock_shared = MagicMock()
mock_channel = MagicMock()

# Set up the mock channel
mock_grpc.insecure_channel.return_value = mock_channel

# Keep original reference
original_get_grpc_channel = lambda *_, **__: "original"
mock_shared.get_grpc_channel = original_get_grpc_channel

with patch.dict(
"sys.modules",
{
"grpc": mock_grpc,
"durabletask.internal.shared": mock_shared,
},
), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch(
"dapr_agents.workflow.base.DaprWorkflowClient"
), patch("dapr_agents.workflow.base.get_default_llm"), patch.object(
WorkflowApp, "start_runtime"
), patch.object(WorkflowApp, "setup_signal_handlers"):
# Create WorkflowApp with gRPC config
max_send = 10 * 1024 * 1024 # 10MB
max_recv = 12 * 1024 * 1024 # 12MB

app = WorkflowApp(
grpc_max_send_message_length=max_send,
grpc_max_receive_message_length=max_recv,
)

# Confirm get_grpc_channel was overridden
assert callable(mock_shared.get_grpc_channel)
assert mock_shared.get_grpc_channel != original_get_grpc_channel

# Call the patched function
test_address = "localhost:50001"
mock_shared.get_grpc_channel(test_address)

# Verify insecure_channel was called with correct options
mock_grpc.insecure_channel.assert_called_once()
call_args = mock_grpc.insecure_channel.call_args

# Check that the address was passed
assert call_args[0][0] == test_address

# Check that options were passed
assert "options" in call_args.kwargs
options = call_args.kwargs["options"]

# Verify options contain our custom message size limits
assert ("grpc.max_send_message_length", max_send) in options
assert ("grpc.max_receive_message_length", max_recv) in options


def test_grpc_config_with_only_send_limit(mock_workflow_dependencies):
"""Test gRPC configuration with only send limit set."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
app = WorkflowApp(
grpc_max_send_message_length=20 * 1024 * 1024, # 20MB
)

# Verify configuration was called
mock_configure.assert_called_once()

# Verify only send limit was set
assert app.grpc_max_send_message_length == 20 * 1024 * 1024
assert app.grpc_max_receive_message_length is None


def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies):
"""Test gRPC configuration with only receive limit set."""
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
app = WorkflowApp(
grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB
)

# Verify configuration was called
mock_configure.assert_called_once()

# Verify only receive limit was set
assert app.grpc_max_send_message_length is None
assert app.grpc_max_receive_message_length == 24 * 1024 * 1024
Loading