diff --git a/dapr_agents/agents/base.py b/dapr_agents/agents/base.py index 07cf8751..da2292b4 100644 --- a/dapr_agents/agents/base.py +++ b/dapr_agents/agents/base.py @@ -12,6 +12,7 @@ AgentRegistryConfig, AgentStateConfig, AgentExecutionConfig, + WorkflowGrpcOptions, DEFAULT_AGENT_WORKFLOW_BUNDLE, ) from dapr_agents.agents.prompting import AgentProfileConfig, PromptingAgentBase @@ -65,6 +66,7 @@ def __init__( tools: Optional[Iterable[Any]] = None, # Metadata agent_metadata: Optional[Dict[str, Any]] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, # Execution execution: Optional[AgentExecutionConfig] = None, ) -> None: @@ -95,6 +97,7 @@ def __init__( tools: Optional tool callables or `AgentTool` instances. agent_metadata: Extra metadata to store in the registry. + workflow_grpc: Optional gRPC overrides for the workflow runtime channel. """ # Resolve and validate profile (ensures non-empty name). resolved_profile = self._build_profile( @@ -118,6 +121,7 @@ def __init__( base_metadata=base_metadata, max_etag_attempts=max_etag_attempts, default_bundle=DEFAULT_AGENT_WORKFLOW_BUNDLE, + workflow_grpc_options=workflow_grpc, ) # ----------------------------- diff --git a/dapr_agents/agents/components.py b/dapr_agents/agents/components.py index fe0504f2..fa1e6798 100644 --- a/dapr_agents/agents/components.py +++ b/dapr_agents/agents/components.py @@ -14,6 +14,7 @@ AgentRegistryConfig, AgentStateConfig, DEFAULT_AGENT_WORKFLOW_BUNDLE, + WorkflowGrpcOptions, StateModelBundle, ) from dapr_agents.agents.schemas import AgentWorkflowEntry @@ -44,6 +45,7 @@ def __init__( registry: Optional[AgentRegistryConfig] = None, base_metadata: Optional[Dict[str, Any]] = None, max_etag_attempts: int = 10, + workflow_grpc_options: Optional[WorkflowGrpcOptions] = None, default_bundle: Optional[StateModelBundle] = None, ) -> None: """ @@ -59,6 +61,7 @@ def __init__( default_bundle: Default state schema bundle (injected by agent/orchestrator class). """ self.name = name + self._workflow_grpc_options = workflow_grpc_options # ----------------------------- # Pub/Sub configuration (copy) @@ -179,6 +182,11 @@ def workflow_state(self) -> BaseModel: """Return the in-memory workflow state model (customizable model).""" return self._state_model + @property + def workflow_grpc_options(self) -> Optional[WorkflowGrpcOptions]: + """Return workflow gRPC tuning options if provided.""" + return self._workflow_grpc_options + @property def state(self) -> Dict[str, Any]: """Return the workflow state as a JSON-serializable dict.""" diff --git a/dapr_agents/agents/configs.py b/dapr_agents/agents/configs.py index 98a27c67..348be554 100644 --- a/dapr_agents/agents/configs.py +++ b/dapr_agents/agents/configs.py @@ -54,6 +54,32 @@ class StateModelBundle: ) +@dataclass +class WorkflowGrpcOptions: + """ + Optional overrides for Durable Task gRPC channel limits. + + Allows agents/orchestrators to lift the default ~4 MB message size + ceiling when sending or receiving large payloads through the workflow + runtime channel. + """ + + max_send_message_length: Optional[int] = None + max_receive_message_length: Optional[int] = None + + def __post_init__(self) -> None: + if ( + self.max_send_message_length is not None + and self.max_send_message_length <= 0 + ): + raise ValueError("max_send_message_length must be greater than 0") + if ( + self.max_receive_message_length is not None + and self.max_receive_message_length <= 0 + ): + raise ValueError("max_receive_message_length must be greater than 0") + + @dataclass class AgentStateConfig: """ diff --git a/dapr_agents/agents/durable.py b/dapr_agents/agents/durable.py index a948446b..fee9cb45 100644 --- a/dapr_agents/agents/durable.py +++ b/dapr_agents/agents/durable.py @@ -13,6 +13,7 @@ AgentRegistryConfig, AgentStateConfig, AgentExecutionConfig, + WorkflowGrpcOptions, ) from dapr_agents.agents.prompting import AgentProfileConfig from dapr_agents.agents.schemas import ( @@ -31,6 +32,7 @@ ) from dapr_agents.types.workflow import DaprWorkflowStatus from dapr_agents.workflow.decorators.routers import message_router +from dapr_agents.workflow.utils.grpc import apply_grpc_options from dapr_agents.workflow.utils.pubsub import broadcast_message, send_message_to_agent logger = logging.getLogger(__name__) @@ -71,6 +73,7 @@ def __init__( execution: Optional[AgentExecutionConfig] = None, # Misc agent_metadata: Optional[Dict[str, Any]] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, runtime: Optional[wf.WorkflowRuntime] = None, ) -> None: """ @@ -96,6 +99,7 @@ def __init__( tools: Optional tool callables or `AgentTool` instances. agent_metadata: Extra metadata to publish to the registry. + workflow_grpc: Optional gRPC overrides for the workflow runtime channel. runtime: Optional pre-existing workflow runtime to attach to. """ super().__init__( @@ -112,11 +116,14 @@ def __init__( registry=registry, execution=execution, agent_metadata=agent_metadata, + workflow_grpc=workflow_grpc, llm=llm, tools=tools, prompt_template=prompt_template, ) + apply_grpc_options(self.workflow_grpc_options) + self._runtime: wf.WorkflowRuntime = runtime or wf.WorkflowRuntime() self._runtime_owned = runtime is None self._registered = False diff --git a/dapr_agents/agents/orchestrators/base.py b/dapr_agents/agents/orchestrators/base.py index f4fb3615..b40c8ac2 100644 --- a/dapr_agents/agents/orchestrators/base.py +++ b/dapr_agents/agents/orchestrators/base.py @@ -13,9 +13,11 @@ AgentPubSubConfig, AgentRegistryConfig, AgentStateConfig, + WorkflowGrpcOptions, StateModelBundle, ) from dapr_agents.agents.utils.text_printer import ColorTextFormatter +from dapr_agents.workflow.utils.grpc import apply_grpc_options logger = logging.getLogger(__name__) @@ -40,6 +42,7 @@ def __init__( registry: Optional[AgentRegistryConfig] = None, execution: Optional[AgentExecutionConfig] = None, agent_metadata: Optional[Dict[str, Any]] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, runtime: Optional[wf.WorkflowRuntime] = None, workflow_client: Optional[wf.DaprWorkflowClient] = None, default_bundle: Optional[StateModelBundle] = None, @@ -54,6 +57,7 @@ def __init__( registry: Agent registry configuration for discovery. agent_metadata: Extra metadata to store in the registry; ``orchestrator=True`` is enforced automatically. + workflow_grpc: Optional gRPC overrides for the workflow runtime channel. runtime: Optional pre-existing workflow runtime to attach to. workflow_client: Optional DaprWorkflowClient for dependency injection/testing. default_bundle: Optional state schema bundle (injected by orchestrator subclass). @@ -63,6 +67,7 @@ def __init__( pubsub=pubsub, state=state, registry=registry, + workflow_grpc_options=workflow_grpc, default_bundle=default_bundle, ) @@ -84,6 +89,8 @@ def __init__( ) # Runtime wiring + apply_grpc_options(self.workflow_grpc_options) + self._runtime: wf.WorkflowRuntime = runtime or wf.WorkflowRuntime() self._runtime_owned = runtime is None self._registered = False diff --git a/dapr_agents/agents/orchestrators/llm/base.py b/dapr_agents/agents/orchestrators/llm/base.py index cd8fa97b..f732decd 100644 --- a/dapr_agents/agents/orchestrators/llm/base.py +++ b/dapr_agents/agents/orchestrators/llm/base.py @@ -12,6 +12,7 @@ AgentRegistryConfig, AgentStateConfig, AgentExecutionConfig, + WorkflowGrpcOptions, ) from dapr_agents.agents.orchestrators.base import OrchestratorBase from dapr_agents.agents.orchestrators.llm.configs import build_llm_state_bundle @@ -49,6 +50,7 @@ def __init__( agent_metadata: Optional[Dict[str, Any]] = None, memory: Optional[AgentMemoryConfig] = None, llm: Optional[ChatClientBase] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, runtime: Optional[wf.WorkflowRuntime] = None, workflow_client: Optional[wf.DaprWorkflowClient] = None, ) -> None: @@ -64,6 +66,7 @@ def __init__( agent_metadata (Optional[Dict[str, Any]]): Metadata to store alongside the registry entry. memory (Optional[AgentMemoryConfig]): Memory configuration for the orchestrator. llm (Optional[ChatClientBase]): LLM client instance. + workflow_grpc (Optional[WorkflowGrpcOptions]): gRPC overrides for the workflow runtime channel. runtime (Optional[wf.WorkflowRuntime]): Workflow runtime configuration. workflow_client (Optional[wf.DaprWorkflowClient]): Dapr workflow client. """ @@ -74,6 +77,7 @@ def __init__( registry=registry, execution=execution, agent_metadata=agent_metadata, + workflow_grpc=workflow_grpc, runtime=runtime, workflow_client=workflow_client, default_bundle=build_llm_state_bundle(), diff --git a/dapr_agents/agents/orchestrators/random.py b/dapr_agents/agents/orchestrators/random.py index 346353a8..4514a511 100644 --- a/dapr_agents/agents/orchestrators/random.py +++ b/dapr_agents/agents/orchestrators/random.py @@ -13,6 +13,7 @@ AgentRegistryConfig, AgentStateConfig, AgentExecutionConfig, + WorkflowGrpcOptions, ) from dapr_agents.agents.orchestrators.base import OrchestratorBase from dapr_agents.agents.schemas import ( @@ -49,6 +50,7 @@ def __init__( registry: Optional[AgentRegistryConfig] = None, agent_metadata: Optional[Dict[str, Any]] = None, execution: Optional[AgentExecutionConfig] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, timeout_seconds: int = 60, runtime: Optional[wf.WorkflowRuntime] = None, ) -> None: @@ -59,6 +61,7 @@ def __init__( registry=registry, execution=execution, agent_metadata=agent_metadata, + workflow_grpc=workflow_grpc, runtime=runtime, ) self.timeout = max(1, timeout_seconds) diff --git a/dapr_agents/agents/orchestrators/roundrobin.py b/dapr_agents/agents/orchestrators/roundrobin.py index 63421462..877a48a5 100644 --- a/dapr_agents/agents/orchestrators/roundrobin.py +++ b/dapr_agents/agents/orchestrators/roundrobin.py @@ -12,6 +12,7 @@ AgentRegistryConfig, AgentStateConfig, AgentExecutionConfig, + WorkflowGrpcOptions, ) from dapr_agents.agents.orchestrators.base import OrchestratorBase from dapr_agents.agents.schemas import ( @@ -43,6 +44,7 @@ def __init__( registry: Optional[AgentRegistryConfig] = None, execution: Optional[AgentExecutionConfig] = None, agent_metadata: Optional[Dict[str, Any]] = None, + workflow_grpc: Optional[WorkflowGrpcOptions] = None, timeout_seconds: int = 60, runtime: Optional[wf.WorkflowRuntime] = None, ) -> None: @@ -53,6 +55,7 @@ def __init__( registry=registry, execution=execution, agent_metadata=agent_metadata, + workflow_grpc=workflow_grpc, runtime=runtime, ) self.timeout = max(1, timeout_seconds) diff --git a/dapr_agents/workflow/utils/grpc.py b/dapr_agents/workflow/utils/grpc.py new file mode 100644 index 00000000..cd6a48e0 --- /dev/null +++ b/dapr_agents/workflow/utils/grpc.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +import logging +from typing import Optional, Sequence + +from dapr_agents.agents.configs import WorkflowGrpcOptions + +logger = logging.getLogger(__name__) + + +# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at +# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19 +# but with my option overrides applied above. +def apply_grpc_options(options: Optional[WorkflowGrpcOptions]) -> None: + """ + Patch Durable Task's gRPC channel factory with custom message size limits. + + Durable Task (and therefore Dapr Workflows) creates its gRPC channels via + ``durabletask.internal.shared.get_grpc_channel``. This helper monkey patches + that factory so that subsequent runtime/client instances honour the provided + ``grpc.max_send_message_length`` / ``grpc.max_receive_message_length`` values. + + Users can set either or both options; any non-None value will be applied. + """ + if not options: + return + # Early return if neither option is set + if ( + options.max_send_message_length is None + and options.max_receive_message_length is None + ): + return + + try: + import grpc + from durabletask.internal import shared + except ImportError as exc: + logger.error( + "Failed to import grpc/durabletask for channel configuration: %s", exc + ) + raise + + grpc_options = [] + if options.max_send_message_length: + grpc_options.append( + ("grpc.max_send_message_length", options.max_send_message_length) + ) + if options.max_receive_message_length: + grpc_options.append( + ("grpc.max_receive_message_length", options.max_receive_message_length) + ) + + def get_grpc_channel_with_options( + host_address: Optional[str], + secure_channel: bool = False, + interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None, + ): + if host_address is None: + host_address = shared.get_default_host_address() + + for protocol in getattr(shared, "SECURE_PROTOCOLS", []): + if host_address.lower().startswith(protocol): + secure_channel = True + host_address = host_address[len(protocol) :] + break + + for protocol in getattr(shared, "INSECURE_PROTOCOLS", []): + if host_address.lower().startswith(protocol): + secure_channel = False + host_address = host_address[len(protocol) :] + break + + if secure_channel: + credentials = grpc.ssl_channel_credentials() + channel = grpc.secure_channel( + host_address, credentials, options=grpc_options + ) + else: + channel = grpc.insecure_channel(host_address, options=grpc_options) + + if interceptors: + channel = grpc.intercept_channel(channel, *interceptors) + + return channel + + shared.get_grpc_channel = get_grpc_channel_with_options + logger.debug( + "Applied gRPC options to durabletask channel factory: %s", dict(grpc_options) + ) diff --git a/tests/workflow/test_grpc_options.py b/tests/workflow/test_grpc_options.py new file mode 100644 index 00000000..9c87b991 --- /dev/null +++ b/tests/workflow/test_grpc_options.py @@ -0,0 +1,104 @@ +from __future__ import annotations + +import types +from unittest.mock import MagicMock, patch + +import pytest + +from dapr_agents.agents.configs import WorkflowGrpcOptions +from dapr_agents.workflow.utils.grpc import apply_grpc_options + + +def create_durabletask_module(shared_module: MagicMock) -> None: + durabletask_module = types.ModuleType("durabletask") + internal_module = types.ModuleType("durabletask.internal") + setattr(durabletask_module, "internal", internal_module) + setattr(internal_module, "shared", shared_module) + import sys + + sys.modules["durabletask"] = durabletask_module + sys.modules["durabletask.internal"] = internal_module + sys.modules["durabletask.internal.shared"] = shared_module + + +@pytest.fixture(autouse=True) +def cleanup_modules(): + import sys + + snapshot = sys.modules.copy() + yield + for key in list(sys.modules.keys()): + if key not in snapshot: + del sys.modules[key] + + +def test_apply_grpc_options_no_options(): + shared = MagicMock() + original = MagicMock() + shared.get_grpc_channel = original + create_durabletask_module(shared) + with patch.dict("sys.modules", {"grpc": MagicMock()}): + apply_grpc_options(None) + assert shared.get_grpc_channel is original + + +def test_apply_grpc_options_only_send(): + grpc_mock = MagicMock() + shared = MagicMock() + shared.get_grpc_channel = MagicMock() + create_durabletask_module(shared) + with patch.dict("sys.modules", {"grpc": grpc_mock}): + opts = WorkflowGrpcOptions(max_send_message_length=16 * 1024 * 1024) + apply_grpc_options(opts) + + assert callable(shared.get_grpc_channel) + shared.get_grpc_channel("localhost:4001") + grpc_mock.insecure_channel.assert_called_once() + call_kwargs = grpc_mock.insecure_channel.call_args.kwargs + assert ("grpc.max_send_message_length", 16 * 1024 * 1024) in call_kwargs[ + "options" + ] + assert "grpc.max_receive_message_length" not in dict(call_kwargs["options"]) + + +def test_apply_grpc_options_only_receive(): + grpc_mock = MagicMock() + shared = MagicMock() + shared.get_grpc_channel = MagicMock() + create_durabletask_module(shared) + with patch.dict("sys.modules", {"grpc": grpc_mock}): + opts = WorkflowGrpcOptions(max_receive_message_length=24 * 1024 * 1024) + apply_grpc_options(opts) + + shared.get_grpc_channel("localhost:4001") + grpc_mock.insecure_channel.assert_called_once() + call_kwargs = grpc_mock.insecure_channel.call_args.kwargs + assert ("grpc.max_receive_message_length", 24 * 1024 * 1024) in call_kwargs[ + "options" + ] + assert "grpc.max_send_message_length" not in dict(call_kwargs["options"]) + + +def test_apply_grpc_options_patch_occurs(): + grpc_mock = MagicMock() + shared = MagicMock() + original = MagicMock() + shared.get_grpc_channel = original + create_durabletask_module(shared) + + with patch.dict("sys.modules", {"grpc": grpc_mock}): + opts = WorkflowGrpcOptions( + max_send_message_length=8 * 1024 * 1024, + max_receive_message_length=12 * 1024 * 1024, + ) + apply_grpc_options(opts) + + assert callable(shared.get_grpc_channel) + assert shared.get_grpc_channel is not original + shared.get_grpc_channel("localhost:50001") + + grpc_mock.insecure_channel.assert_called_once() + kwargs = grpc_mock.insecure_channel.call_args.kwargs + options = dict(kwargs["options"]) + assert options["grpc.max_send_message_length"] == 8 * 1024 * 1024 + assert options["grpc.max_receive_message_length"] == 12 * 1024 * 1024