Skip to content

Commit 93f4477

Browse files
authored
Merge branch 'main' into fix-revert-so-easier-for-big-pr
2 parents 49225ac + 56b2777 commit 93f4477

File tree

32 files changed

+2132
-1127
lines changed

32 files changed

+2132
-1127
lines changed

dapr_agents/document/__init__.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from typing import TYPE_CHECKING
2+
13
from .embedder import NVIDIAEmbedder, OpenAIEmbedder, SentenceTransformerEmbedder
24
from .fetcher import ArxivFetcher
35
from .reader import PyMuPDFReader, PyPDFReader
4-
from .splitter import TextSplitter
6+
7+
if TYPE_CHECKING:
8+
from .splitter import TextSplitter
59

610
__all__ = [
711
"ArxivFetcher",
@@ -12,3 +16,12 @@
1216
"SentenceTransformerEmbedder",
1317
"NVIDIAEmbedder",
1418
]
19+
20+
21+
def __getattr__(name: str):
22+
"""Lazy import for optional dependencies."""
23+
if name == "TextSplitter":
24+
from .splitter import TextSplitter
25+
26+
return TextSplitter
27+
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

dapr_agents/llm/utils/request.py

Lines changed: 32 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
1-
from typing import Dict, Any, Optional, List, Type, Union, Iterable, Literal
2-
from dapr_agents.prompt.prompty import Prompty, PromptyHelper
3-
from dapr_agents.types.message import BaseMessage
4-
from dapr_agents.llm.utils.structure import StructureHandler
5-
from dapr_agents.tool.utils.tool import ToolHelper
1+
import logging
2+
from typing import Any, Dict, Iterable, List, Literal, Optional, Type, Union
3+
64
from pydantic import BaseModel, ValidationError
7-
from dapr_agents.tool.base import AgentTool
85

9-
import logging
6+
from dapr_agents.llm.utils.structure import StructureHandler
7+
from dapr_agents.prompt.prompty import Prompty, PromptyHelper
8+
from dapr_agents.tool.base import AgentTool
9+
from dapr_agents.tool.utils.tool import ToolHelper
10+
from dapr_agents.types.message import BaseMessage
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -100,46 +101,38 @@ def process_params(
100101
Prepare request parameters for the language model.
101102
102103
Args:
103-
params: Parameters for the request.
104-
llm_provider: The LLM provider to use (e.g., 'openai').
105-
tools: List of tools to include in the request.
106-
response_format: Either a Pydantic model (for function calling)
107-
or a JSON Schema definition/dict (for raw JSON structured output).
108-
structured_mode: The mode of structured output: 'json' or 'function_call'.
109-
Defaults to 'json'.
104+
params: Raw request params (messages/inputs, model, etc.).
105+
llm_provider: Provider key, e.g. "openai", "dapr".
106+
tools: Tools to expose to the model (AgentTool or already-shaped dicts).
107+
response_format:
108+
- If structured_mode == "json": a JSON Schema dict or a Pydantic model
109+
(we'll convert) to request raw JSON output.
110+
- If structured_mode == "function_call": a Pydantic model describing
111+
the function/tool signature for model-side function calling.
112+
structured_mode: "json" for raw JSON structured output,
113+
"function_call" for tool/function calling.
110114
111115
Returns:
112-
Dict[str, Any]: Prepared request parameters.
116+
A params dict ready for the target provider.
113117
"""
118+
119+
# Tools
114120
if tools:
115121
logger.info("Tools are available in the request.")
116-
# Convert AgentTool objects to dict format for the provider
117-
tool_dicts = []
118-
for tool in tools:
119-
if isinstance(tool, AgentTool):
120-
tool_dicts.append(
121-
ToolHelper.format_tool(tool, tool_format=llm_provider)
122-
)
123-
else:
124-
tool_dicts.append(
125-
ToolHelper.format_tool(tool, tool_format=llm_provider)
126-
)
127-
params["tools"] = tool_dicts
122+
params["tools"] = [
123+
ToolHelper.format_tool(t, tool_format=llm_provider) for t in tools
124+
]
128125

126+
# Structured output
129127
if response_format:
130-
logger.info(f"Structured Mode Activated! Mode={structured_mode}.")
131-
# Add system message for JSON formatting
132-
# This is necessary for the response formatting of the data to work correctly when a user has a function call response format.
133-
inputs = params.get("inputs", [])
134-
inputs.insert(
135-
0,
136-
{
137-
"role": "system",
138-
"content": "You must format your response as a valid JSON object matching the provided schema. Do not include any explanatory text or markdown formatting.",
139-
},
140-
)
141-
params["inputs"] = inputs
128+
logger.info(f"Structured Mode Activated! mode={structured_mode}")
129+
130+
# If we're on Dapr, we cannot rely on OpenAI-style `response_format`.
131+
# Add a small system nudge to enforce JSON-only output so we can parse reliably.
132+
if llm_provider == "dapr":
133+
params = StructureHandler.ensure_json_only_system_prompt(params)
142134

135+
# Generate provider-specific request params
143136
params = StructureHandler.generate_request(
144137
response_format=response_format,
145138
llm_provider=llm_provider,

dapr_agents/llm/utils/structure.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -627,3 +627,30 @@ def validate_against_signature(result: Any, expected_type: Any) -> Any:
627627
return adapter.validate_python(result)
628628
except ValidationError as e:
629629
raise TypeError(f"Validation failed for type {expected_type}: {e}")
630+
631+
@staticmethod
632+
def ensure_json_only_system_prompt(params: Dict[str, Any]) -> Dict[str, Any]:
633+
"""
634+
Dapr's chat client (today) does NOT forward OpenAI-style `response_format`
635+
(e.g., {"type":"json_schema", ...}). That means the model won't be hard-constrained
636+
to your schema. As a fallback, we prepend a system message that instructs the
637+
model to return strict JSON so downstream parsing doesn't break.
638+
639+
Note:
640+
- Dapr uses "inputs" (not "messages") for the message array.
641+
- If "inputs" isn't present (future providers), we fall back to "messages".
642+
"""
643+
collection_key = "inputs" if "inputs" in params else "messages"
644+
msgs = list(params.get(collection_key, []))
645+
msgs.insert(
646+
0,
647+
{
648+
"role": "system",
649+
"content": (
650+
"Return ONLY a valid JSON object that matches the provided schema. "
651+
"No markdown, no code fences, no explanations—JSON object only."
652+
),
653+
},
654+
)
655+
params[collection_key] = msgs
656+
return params

dapr_agents/workflow/base.py

Lines changed: 119 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88
import uuid
99
from datetime import datetime, timezone
10-
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union
10+
from typing import Any, Callable, Dict, List, Optional, TypeVar, Union, Sequence
1111

1212
from dapr.ext.workflow import (
1313
DaprWorkflowClient,
@@ -16,7 +16,7 @@
1616
)
1717
from dapr.ext.workflow.workflow_state import WorkflowState
1818
from durabletask import task as dtask
19-
from pydantic import BaseModel, ConfigDict, Field
19+
from pydantic import BaseModel, ConfigDict, Field, model_validator
2020

2121
from dapr_agents.agents.base import ChatClientBase
2222
from dapr_agents.llm.utils.defaults import get_default_llm
@@ -46,6 +46,14 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
4646
default=300,
4747
description="Default timeout duration in seconds for workflow tasks.",
4848
)
49+
grpc_max_send_message_length: Optional[int] = Field(
50+
default=None,
51+
description="Maximum message length in bytes for gRPC send operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
52+
)
53+
grpc_max_receive_message_length: Optional[int] = Field(
54+
default=None,
55+
description="Maximum message length in bytes for gRPC receive operations. Default is 4MB if not specified. Useful for AI workflows with large payloads (e.g., images).",
56+
)
4957

5058
# Initialized in model_post_init
5159
wf_runtime: Optional[WorkflowRuntime] = Field(
@@ -68,10 +76,30 @@ class WorkflowApp(BaseModel, SignalHandlingMixin):
6876

6977
model_config = ConfigDict(arbitrary_types_allowed=True)
7078

79+
@model_validator(mode="before")
80+
def validate_grpc_chanell_options(cls, values: Any):
81+
if not isinstance(values, dict):
82+
return values
83+
84+
if values.get("grpc_max_send_message_length") is not None:
85+
if values["grpc_max_send_message_length"] < 0:
86+
raise ValueError("grpc_max_send_message_length must be greater than 0")
87+
88+
if values.get("grpc_max_receive_message_length") is not None:
89+
if values["grpc_max_receive_message_length"] < 0:
90+
raise ValueError(
91+
"grpc_max_receive_message_length must be greater than 0"
92+
)
93+
94+
return values
95+
7196
def model_post_init(self, __context: Any) -> None:
7297
"""
7398
Initialize the Dapr workflow runtime and register tasks & workflows.
7499
"""
100+
if self.grpc_max_send_message_length or self.grpc_max_receive_message_length:
101+
self._configure_grpc_channel_options()
102+
75103
# Initialize LLM first
76104
if self.llm is None:
77105
self.llm = get_default_llm()
@@ -92,6 +120,95 @@ def model_post_init(self, __context: Any) -> None:
92120

93121
super().model_post_init(__context)
94122

123+
def _configure_grpc_channel_options(self) -> None:
124+
"""
125+
Configure gRPC channel options before workflow runtime initialization.
126+
This patches the durabletask internal channel factory to support custom message size limits.
127+
128+
This is particularly useful for AI-powered workflows that may need to handle large payloads
129+
such as images, which can exceed the default 4MB gRPC message size limit.
130+
"""
131+
try:
132+
import grpc
133+
from durabletask.internal import shared
134+
135+
# Create custom options list
136+
options = []
137+
if self.grpc_max_send_message_length:
138+
options.append(
139+
("grpc.max_send_message_length", self.grpc_max_send_message_length)
140+
)
141+
logger.debug(
142+
f"Configured gRPC max_send_message_length: {self.grpc_max_send_message_length} bytes ({self.grpc_max_send_message_length / (1024 * 1024):.2f} MB)"
143+
)
144+
if self.grpc_max_receive_message_length:
145+
options.append(
146+
(
147+
"grpc.max_receive_message_length",
148+
self.grpc_max_receive_message_length,
149+
)
150+
)
151+
logger.debug(
152+
f"Configured gRPC max_receive_message_length: {self.grpc_max_receive_message_length} bytes ({self.grpc_max_receive_message_length / (1024 * 1024):.2f} MB)"
153+
)
154+
155+
# Patch the function to include our custom options
156+
def get_grpc_channel_with_options(
157+
host_address: Optional[str],
158+
secure_channel: bool = False,
159+
interceptors: Optional[Sequence["grpc.ClientInterceptor"]] = None,
160+
):
161+
# This is a copy of the original get_grpc_channel function in durabletask.internal.shared at
162+
# https://github.com/dapr/durabletask-python/blob/7070cb07d07978d079f8c099743ee4a66ae70e05/durabletask/internal/shared.py#L30C1-L61C19
163+
# but with my option overrides applied above.
164+
if host_address is None:
165+
host_address = shared.get_default_host_address()
166+
167+
for protocol in getattr(shared, "SECURE_PROTOCOLS", []):
168+
if host_address.lower().startswith(protocol):
169+
secure_channel = True
170+
# remove the protocol from the host name
171+
host_address = host_address[len(protocol) :]
172+
break
173+
174+
for protocol in getattr(shared, "INSECURE_PROTOCOLS", []):
175+
if host_address.lower().startswith(protocol):
176+
secure_channel = False
177+
# remove the protocol from the host name
178+
host_address = host_address[len(protocol) :]
179+
break
180+
181+
# Create the base channel
182+
if secure_channel:
183+
credentials = grpc.ssl_channel_credentials()
184+
channel = grpc.secure_channel(
185+
host_address, credentials, options=options
186+
)
187+
else:
188+
channel = grpc.insecure_channel(host_address, options=options)
189+
190+
# Apply interceptors ONLY if they exist
191+
if interceptors:
192+
channel = grpc.intercept_channel(channel, *interceptors)
193+
194+
return channel
195+
196+
# Replace the function
197+
shared.get_grpc_channel = get_grpc_channel_with_options
198+
199+
logger.debug(
200+
"Successfully patched durabletask gRPC channel factory with custom options"
201+
)
202+
203+
except ImportError as e:
204+
logger.error(
205+
f"Failed to import required modules for gRPC configuration: {e}"
206+
)
207+
raise
208+
except Exception as e:
209+
logger.error(f"Failed to configure gRPC channel options: {e}")
210+
raise
211+
95212
def graceful_shutdown(self) -> None:
96213
"""
97214
Perform graceful shutdown operations for the WorkflowApp.
Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,13 @@
11
from .core import task, workflow
22
from .fastapi import route
33
from .messaging import message_router
4+
from .activities import llm_activity, agent_activity
45

5-
__all__ = ["workflow", "task", "route", "message_router"]
6+
__all__ = [
7+
"workflow",
8+
"task",
9+
"route",
10+
"message_router",
11+
"llm_activity",
12+
"agent_activity",
13+
]

0 commit comments

Comments
 (0)