Skip to content

Commit 2f1ac94

Browse files
committed
Merge remote-tracking branch 'origin/main' into cyb3rward0g/llm-agent-activities
2 parents 1a16a63 + 6db5130 commit 2f1ac94

File tree

8 files changed

+356
-1188
lines changed

8 files changed

+356
-1188
lines changed

dapr_agents/workflow/base.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import uuid
99
from datetime import datetime, timezone
10-
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
10+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence
1111

1212
from dapr.ext.workflow import (
1313
DaprWorkflowClient,
@@ -16,7 +16,7 @@
1616
)
1717
from dapr.ext.workflow.workflow_state import WorkflowState
1818
from durabletask import task as dtask
19-
from pydantic import BaseModel, ConfigDict, Field
19+
from pydantic import BaseModel, ConfigDict, Field, model_validator
2020

2121
from dapr_agents.agents.base import ChatClientBase
2222
from dapr_agents.llm.utils.defaults import get_default_llm
@@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
4646
default=300,
4747
description="Default timeout duration in seconds for workflow tasks.",
4848
)
49+
grpc_max_send_message_length: Optional[int] = Field(
50+
default=None,
51+
description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
52+
)
53+
grpc_max_receive_message_length: Optional[int] = Field(
54+
default=None,
55+
description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
56+
)
4957

5058
# Initialized in model_post_init
5159
wf_runtime: Optional[WorkflowRuntime] = Field(
@@ -68,10 +76,30 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
6876

6977
model_config = ConfigDict(arbitrary_types_allowed=True)
7078

79+
@model_validator(mode="before")
80+
def validate_grpc_chanell_options(cls, values: Any):
81+
if not isinstance(values, dict):
82+
return values
83+
84+
if values.get("grpc_max_send_message_length") is not None:
85+
if values["grpc_max_send_message_length"] < 0:
86+
raise ValueError("grpc_max_send_message_length must be greater than 0")
87+
88+
if values.get("grpc_max_receive_message_length") is not None:
89+
if values["grpc_max_receive_message_length"] < 0:
90+
raise ValueError(
91+
"grpc_max_receive_message_length must be greater than 0"
92+
)
93+
94+
return values
95+
7196
def model_post_init(self, __context: Any) -> None:
7297
"""
7398
Initialize the Dapr workflow runtime and register tasks & workflows.
7499
"""
100+
if self.grpc_max_send_message_length or self.grpc_max_receive_message_length:
101+
self._configure_grpc_channel_options()
102+
75103
# Initialize LLM first
76104
if self.llm is None:
77105
self.llm = get_default_llm()
@@ -92,6 +120,95 @@ def model_post_init(self, __context: Any) -> None:
92120

93121
super().model_post_init(__context)
94122

123+
def _configure_grpc_channel_options(self) -> None:
124+
"""
125+
Configure gRPC channel options before workflow runtime initialization.
126+
This patches the durabletask internal channel factory to support custom message size limits.
127+
128+
This is particularly useful for AI-powered workflows that may need to handle large payloads
129+
such as images, which can exceed the default 4MB gRPC message size limit.
130+
"""
131+
try:
132+
import grpc
133+
from durabletask.internal import shared
134+
135+
# Create custom options list
136+
options = []
137+
if self.grpc_max_send_message_length:
138+
options.append(
139+
("grpc.max_send_message_length", self.grpc_max_send_message_length)
140+
)
141+
logger.debug(
142+
f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)"
143+
)
144+
if self.grpc_max_receive_message_length:
145+
options.append(
146+
(
147+
"grpc.max_receive_message_length",
148+
self.grpc_max_receive_message_length,
149+
)
150+
)
151+
logger.debug(
152+
f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)"
153+
)
154+
155+
# Patch the function to include our custom options
156+
def get_grpc_channel_with_options(
157+
host_address: Optional[str],
158+
secure_channel: bool = False,
159+
interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None,
160+
):
161+
# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at
162+
# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19
163+
# but with my option overrides applied above.
164+
if host_address is None:
165+
host_address = shared.get_default_host_address()
166+
167+
for protocol in getattr(shared, "SECURE_PROTOCOLS", []):
168+
if host_address.lower().startswith(protocol):
169+
secure_channel = True
170+
# remove the protocol from the host name
171+
host_address = host_address[len(protocol) :]
172+
break
173+
174+
for protocol in getattr(shared, "INSECURE_PROTOCOLS", []):
175+
if host_address.lower().startswith(protocol):
176+
secure_channel = False
177+
# remove the protocol from the host name
178+
host_address = host_address[len(protocol) :]
179+
break
180+
181+
# Create the base channel
182+
if secure_channel:
183+
credentials = grpc.ssl_channel_credentials()
184+
channel = grpc.secure_channel(
185+
host_address, credentials, options=options
186+
)
187+
else:
188+
channel = grpc.insecure_channel(host_address, options=options)
189+
190+
# Apply interceptors ONLY if they exist
191+
if interceptors:
192+
channel = grpc.intercept_channel(channel, *interceptors)
193+
194+
return channel
195+
196+
# Replace the function
197+
shared.get_grpc_channel = get_grpc_channel_with_options
198+
199+
logger.debug(
200+
"Successfully patched durabletask gRPC channel factory with custom options"
201+
)
202+
203+
except ImportError as e:
204+
logger.error(
205+
f"Failed to import required modules for gRPC configuration: {e}"
206+
)
207+
raise
208+
except Exception as e:
209+
logger.error(f"Failed to configure gRPC channel options: {e}")
210+
raise
211+
95212
def graceful_shutdown(self) -> None:
96213
"""
97214
Perform graceful shutdown operations for the WorkflowApp.

0 commit comments

Comments
 (0)