Skip to content

Commit c054655

Browse files
authored
feat: enable grpc config on agent instantiation (#238)
* feat: enable grpc config on agent instantiation Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e flake8 Signed-off-by: Samantha Coyle <[email protected]> * fix: correct test assertion Signed-off-by: Samantha Coyle <[email protected]> * fix: add validations and correct func params Signed-off-by: Samantha Coyle <[email protected]> * style: lint Signed-off-by: Samantha Coyle <[email protected]> * fix: tox -e flake8 Signed-off-by: Samantha Coyle <[email protected]> * fix: style again Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * fix: update for test to be happy Signed-off-by: Samantha Coyle <[email protected]> * style: tox -e ruff Signed-off-by: Samantha Coyle <[email protected]> * fix: update for test Signed-off-by: Samantha Coyle <[email protected]> --------- Signed-off-by: Samantha Coyle <[email protected]>
1 parent e99a88d commit c054655

File tree

2 files changed

+331
-2
lines changed

2 files changed

+331
-2
lines changed

dapr_agents/workflow/base.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import uuid
99
from datetime import datetime, timezone
10-
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
10+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence
1111

1212
from dapr.ext.workflow import (
1313
DaprWorkflowClient,
@@ -16,7 +16,7 @@
1616
)
1717
from dapr.ext.workflow.workflow_state import WorkflowState
1818
from durabletask import task as dtask
19-
from pydantic import BaseModel, ConfigDict, Field
19+
from pydantic import BaseModel, ConfigDict, Field, model_validator
2020

2121
from dapr_agents.agents.base import ChatClientBase
2222
from dapr_agents.llm.utils.defaults import get_default_llm
@@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
4646
default=300,
4747
description="Default timeout duration in seconds for workflow tasks.",
4848
)
49+
grpc_max_send_message_length: Optional[int] = Field(
50+
default=None,
51+
description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
52+
)
53+
grpc_max_receive_message_length: Optional[int] = Field(
54+
default=None,
55+
description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
56+
)
4957

5058
# Initialized in model_post_init
5159
wf_runtime: Optional[WorkflowRuntime] = Field(
@@ -68,10 +76,30 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
6876

6977
model_config = ConfigDict(arbitrary_types_allowed=True)
7078

79+
@model_validator(mode="before")
80+
def validate_grpc_chanell_options(cls, values: Any):
81+
if not isinstance(values, dict):
82+
return values
83+
84+
if values.get("grpc_max_send_message_length") is not None:
85+
if values["grpc_max_send_message_length"] < 0:
86+
raise ValueError("grpc_max_send_message_length must be greater than 0")
87+
88+
if values.get("grpc_max_receive_message_length") is not None:
89+
if values["grpc_max_receive_message_length"] < 0:
90+
raise ValueError(
91+
"grpc_max_receive_message_length must be greater than 0"
92+
)
93+
94+
return values
95+
7196
def model_post_init(self, __context: Any) -> None:
7297
"""
7398
Initialize the Dapr workflow runtime and register tasks & workflows.
7499
"""
100+
if self.grpc_max_send_message_length or self.grpc_max_receive_message_length:
101+
self._configure_grpc_channel_options()
102+
75103
# Initialize LLM first
76104
if self.llm is None:
77105
self.llm = get_default_llm()
@@ -92,6 +120,95 @@ def model_post_init(self, __context: Any) -> None:
92120

93121
super().model_post_init(__context)
94122

123+
def _configure_grpc_channel_options(self) -> None:
124+
"""
125+
Configure gRPC channel options before workflow runtime initialization.
126+
This patches the durabletask internal channel factory to support custom message size limits.
127+
128+
This is particularly useful for AI-powered workflows that may need to handle large payloads
129+
such as images, which can exceed the default 4MB gRPC message size limit.
130+
"""
131+
try:
132+
import grpc
133+
from durabletask.internal import shared
134+
135+
# Create custom options list
136+
options = []
137+
if self.grpc_max_send_message_length:
138+
options.append(
139+
("grpc.max_send_message_length", self.grpc_max_send_message_length)
140+
)
141+
logger.debug(
142+
f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)"
143+
)
144+
if self.grpc_max_receive_message_length:
145+
options.append(
146+
(
147+
"grpc.max_receive_message_length",
148+
self.grpc_max_receive_message_length,
149+
)
150+
)
151+
logger.debug(
152+
f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)"
153+
)
154+
155+
# Patch the function to include our custom options
156+
def get_grpc_channel_with_options(
157+
host_address: Optional[str],
158+
secure_channel: bool = False,
159+
interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None,
160+
):
161+
# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at
162+
# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19
163+
# but with my option overrides applied above.
164+
if host_address is None:
165+
host_address = shared.get_default_host_address()
166+
167+
for protocol in getattr(shared, "SECURE_PROTOCOLS", []):
168+
if host_address.lower().startswith(protocol):
169+
secure_channel = True
170+
# remove the protocol from the host name
171+
host_address = host_address[len(protocol) :]
172+
break
173+
174+
for protocol in getattr(shared, "INSECURE_PROTOCOLS", []):
175+
if host_address.lower().startswith(protocol):
176+
secure_channel = False
177+
# remove the protocol from the host name
178+
host_address = host_address[len(protocol) :]
179+
break
180+
181+
# Create the base channel
182+
if secure_channel:
183+
credentials = grpc.ssl_channel_credentials()
184+
channel = grpc.secure_channel(
185+
host_address, credentials, options=options
186+
)
187+
else:
188+
channel = grpc.insecure_channel(host_address, options=options)
189+
190+
# Apply interceptors ONLY if they exist
191+
if interceptors:
192+
channel = grpc.intercept_channel(channel, *interceptors)
193+
194+
return channel
195+
196+
# Replace the function
197+
shared.get_grpc_channel = get_grpc_channel_with_options
198+
199+
logger.debug(
200+
"Successfully patched durabletask gRPC channel factory with custom options"
201+
)
202+
203+
except ImportError as e:
204+
logger.error(
205+
f"Failed to import required modules for gRPC configuration: {e}"
206+
)
207+
raise
208+
except Exception as e:
209+
logger.error(f"Failed to configure gRPC channel options: {e}")
210+
raise
211+
95212
def graceful_shutdown(self) -> None:
96213
"""
97214
Perform graceful shutdown operations for the WorkflowApp.

tests/workflow/test_grpc_config.py

Lines changed: 212 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,212 @@
1+
"""Tests for gRPC configuration in WorkflowApp."""
2+
import pytest
3+
from unittest.mock import MagicMock, patch, call
4+
import types
5+
from dapr_agents.workflow.base import WorkflowApp
6+
7+
8+
@pytest.fixture
9+
def mock_workflow_dependencies():
10+
"""Mock all the dependencies needed for WorkflowApp initialization."""
11+
with patch("dapr_agents.workflow.base.WorkflowRuntime") as mock_runtime, patch(
12+
"dapr_agents.workflow.base.DaprWorkflowClient"
13+
) as mock_client, patch(
14+
"dapr_agents.workflow.base.get_default_llm"
15+
) as mock_llm, patch.object(
16+
WorkflowApp, "start_runtime"
17+
) as mock_start, patch.object(
18+
WorkflowApp, "setup_signal_handlers"
19+
) as mock_handlers:
20+
mock_runtime_instance = MagicMock()
21+
mock_runtime.return_value = mock_runtime_instance
22+
23+
mock_client_instance = MagicMock()
24+
mock_client.return_value = mock_client_instance
25+
26+
mock_llm_instance = MagicMock()
27+
mock_llm.return_value = mock_llm_instance
28+
29+
yield {
30+
"runtime": mock_runtime,
31+
"runtime_instance": mock_runtime_instance,
32+
"client": mock_client,
33+
"client_instance": mock_client_instance,
34+
"llm": mock_llm,
35+
"llm_instance": mock_llm_instance,
36+
"start_runtime": mock_start,
37+
"signal_handlers": mock_handlers,
38+
}
39+
40+
41+
def test_workflow_app_without_grpc_config(mock_workflow_dependencies):
42+
"""Test that WorkflowApp initializes without gRPC configuration."""
43+
# Create WorkflowApp without gRPC config
44+
app = WorkflowApp()
45+
46+
# Verify the app was created
47+
assert app is not None
48+
assert app.grpc_max_send_message_length is None
49+
assert app.grpc_max_receive_message_length is None
50+
51+
# Verify runtime and client were initialized
52+
assert app.wf_runtime is not None
53+
assert app.wf_client is not None
54+
55+
56+
def test_workflow_app_with_grpc_config(mock_workflow_dependencies):
57+
"""Test that WorkflowApp initializes with gRPC configuration."""
58+
# Mock the grpc module and durabletask shared module
59+
mock_grpc = MagicMock()
60+
mock_shared = MagicMock()
61+
mock_channel = MagicMock()
62+
63+
# Set up the mock channel
64+
mock_grpc.insecure_channel.return_value = mock_channel
65+
mock_shared.get_grpc_channel = MagicMock()
66+
67+
with patch.dict(
68+
"sys.modules",
69+
{
70+
"grpc": mock_grpc,
71+
"durabletask.internal.shared": mock_shared,
72+
},
73+
):
74+
# Create WorkflowApp with gRPC config (16MB)
75+
app = WorkflowApp(
76+
grpc_max_send_message_length=16 * 1024 * 1024, # 16MB
77+
grpc_max_receive_message_length=16 * 1024 * 1024, # 16MB
78+
)
79+
80+
# Verify the configuration was set
81+
assert app.grpc_max_send_message_length == 16 * 1024 * 1024
82+
assert app.grpc_max_receive_message_length == 16 * 1024 * 1024
83+
84+
# Verify runtime and client were initialized
85+
assert app.wf_runtime is not None
86+
assert app.wf_client is not None
87+
88+
89+
def test_configure_grpc_channel_options_is_called(mock_workflow_dependencies):
90+
"""Test that _configure_grpc_channel_options is called when gRPC config is provided."""
91+
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
92+
# Create WorkflowApp with gRPC config
93+
WorkflowApp(
94+
grpc_max_send_message_length=8 * 1024 * 1024, # 8MB
95+
)
96+
97+
# Verify the configuration method was called
98+
mock_configure.assert_called_once()
99+
100+
101+
def test_configure_grpc_channel_options_not_called_without_config(
102+
mock_workflow_dependencies,
103+
):
104+
"""Test that _configure_grpc_channel_options is not called without gRPC config."""
105+
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
106+
# Create WorkflowApp without gRPC config
107+
WorkflowApp()
108+
109+
# Verify the configuration method was NOT called
110+
mock_configure.assert_not_called()
111+
112+
113+
def test_grpc_channel_patching():
114+
"""Test that the gRPC channel factory is properly patched with custom options."""
115+
# Mock the grpc module and durabletask shared module
116+
mock_grpc = MagicMock()
117+
mock_shared = MagicMock()
118+
mock_channel = MagicMock()
119+
120+
# Set up the mock channel
121+
mock_grpc.insecure_channel.return_value = mock_channel
122+
123+
# Keep original reference
124+
def original_get_grpc_channel(*_, **__):
125+
return "original"
126+
127+
mock_shared.get_grpc_channel = original_get_grpc_channel
128+
129+
# Create dummy package/module structure so 'from durabletask.internal import shared' works
130+
durabletask_module = types.ModuleType("durabletask")
131+
internal_module = types.ModuleType("durabletask.internal")
132+
setattr(durabletask_module, "internal", internal_module)
133+
setattr(internal_module, "shared", mock_shared)
134+
135+
with patch.dict(
136+
"sys.modules",
137+
{
138+
"grpc": mock_grpc,
139+
"durabletask": durabletask_module,
140+
"durabletask.internal": internal_module,
141+
"durabletask.internal.shared": mock_shared,
142+
},
143+
), patch("dapr_agents.workflow.base.WorkflowRuntime"), patch(
144+
"dapr_agents.workflow.base.DaprWorkflowClient"
145+
), patch("dapr_agents.workflow.base.get_default_llm"), patch.object(
146+
WorkflowApp, "start_runtime"
147+
), patch.object(WorkflowApp, "setup_signal_handlers"):
148+
# Create WorkflowApp with gRPC config
149+
max_send = 10 * 1024 * 1024 # 10MB
150+
max_recv = 12 * 1024 * 1024 # 12MB
151+
152+
WorkflowApp(
153+
grpc_max_send_message_length=max_send,
154+
grpc_max_receive_message_length=max_recv,
155+
)
156+
157+
# Confirm get_grpc_channel was overridden
158+
assert callable(mock_shared.get_grpc_channel)
159+
assert mock_shared.get_grpc_channel is not original_get_grpc_channel
160+
assert (
161+
getattr(mock_shared.get_grpc_channel, "__name__", "")
162+
== "get_grpc_channel_with_options"
163+
)
164+
165+
# Call the patched function
166+
test_address = "localhost:50001"
167+
mock_shared.get_grpc_channel(test_address)
168+
169+
# Verify insecure_channel was called with correct options
170+
mock_grpc.insecure_channel.assert_called_once()
171+
call_args = mock_grpc.insecure_channel.call_args
172+
173+
# Check that the address was passed
174+
assert call_args[0][0] == test_address
175+
176+
# Check that options were passed
177+
assert "options" in call_args.kwargs
178+
options = call_args.kwargs["options"]
179+
180+
# Verify options contain our custom message size limits
181+
assert ("grpc.max_send_message_length", max_send) in options
182+
assert ("grpc.max_receive_message_length", max_recv) in options
183+
184+
185+
def test_grpc_config_with_only_send_limit(mock_workflow_dependencies):
186+
"""Test gRPC configuration with only send limit set."""
187+
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
188+
app = WorkflowApp(
189+
grpc_max_send_message_length=20 * 1024 * 1024, # 20MB
190+
)
191+
192+
# Verify configuration was called
193+
mock_configure.assert_called_once()
194+
195+
# Verify only send limit was set
196+
assert app.grpc_max_send_message_length == 20 * 1024 * 1024
197+
assert app.grpc_max_receive_message_length is None
198+
199+
200+
def test_grpc_config_with_only_receive_limit(mock_workflow_dependencies):
201+
"""Test gRPC configuration with only receive limit set."""
202+
with patch.object(WorkflowApp, "_configure_grpc_channel_options") as mock_configure:
203+
app = WorkflowApp(
204+
grpc_max_receive_message_length=24 * 1024 * 1024, # 24MB
205+
)
206+
207+
# Verify configuration was called
208+
mock_configure.assert_called_once()
209+
210+
# Verify only receive limit was set
211+
assert app.grpc_max_send_message_length is None
212+
assert app.grpc_max_receive_message_length == 24 * 1024 * 1024

0 commit comments

Comments
 (0)