
Commit 56b2777

Merge pull request #232 from dapr/cyb3rward0g/llm-agent-activities
Enable native Dapr workflows with LLM and Agent decorators

2 parents 6db5130 + 2f1ac94

25 files changed (+1784, -258 lines)

dapr_agents/document/__init__.py

Lines changed: 14 additions & 1 deletion
@@ -1,7 +1,11 @@
+from typing import TYPE_CHECKING
+
 from .embedder import NVIDIAEmbedder, OpenAIEmbedder, SentenceTransformerEmbedder
 from .fetcher import ArxivFetcher
 from .reader import PyMuPDFReader, PyPDFReader
-from .splitter import TextSplitter
+
+if TYPE_CHECKING:
+    from .splitter import TextSplitter
 
 __all__ = [
     "ArxivFetcher",
@@ -12,3 +16,12 @@
     "SentenceTransformerEmbedder",
     "NVIDIAEmbedder",
 ]
+
+
+def __getattr__(name: str):
+    """Lazy import for optional dependencies."""
+    if name == "TextSplitter":
+        from .splitter import TextSplitter
+
+        return TextSplitter
+    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
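
For context, a module-level __getattr__ (PEP 562) defers the splitter import until TextSplitter is first accessed, so the optional splitter dependency is only required when it is actually used. A minimal sketch of that behavior, using only names from the diff above (the failing lookup is illustrative):

import dapr_agents.document as doc

fetcher_cls = doc.ArxivFetcher    # eagerly imported when the package loads
splitter_cls = doc.TextSplitter   # first access runs __getattr__, importing .splitter lazily

try:
    doc.NotARealName              # any other unknown attribute still raises AttributeError
except AttributeError as exc:
    print(exc)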

dapr_agents/llm/utils/request.py

Lines changed: 32 additions & 39 deletions
@@ -1,12 +1,13 @@
-from typing import Dict, Any, Optional, List, Type, Union, Iterable, Literal
-from dapr_agents.prompt.prompty import Prompty, PromptyHelper
-from dapr_agents.types.message import BaseMessage
-from dapr_agents.llm.utils.structure import StructureHandler
-from dapr_agents.tool.utils.tool import ToolHelper
+import logging
+from typing import Any, Dict, Iterable, List, Literal, Optional, Type, Union
+
 from pydantic import BaseModel, ValidationError
-from dapr_agents.tool.base import AgentTool
 
-import logging
+from dapr_agents.llm.utils.structure import StructureHandler
+from dapr_agents.prompt.prompty import Prompty, PromptyHelper
+from dapr_agents.tool.base import AgentTool
+from dapr_agents.tool.utils.tool import ToolHelper
+from dapr_agents.types.message import BaseMessage
 
 logger = logging.getLogger(__name__)
 
@@ -100,46 +101,38 @@ def process_params(
         Prepare request parameters for the language model.
 
         Args:
-            params: Parameters for the request.
-            llm_provider: The LLM provider to use (e.g., 'openai').
-            tools: List of tools to include in the request.
-            response_format: Either a Pydantic model (for function calling)
-                or a JSON Schema definition/dict (for raw JSON structured output).
-            structured_mode: The mode of structured output: 'json' or 'function_call'.
-                Defaults to 'json'.
+            params: Raw request params (messages/inputs, model, etc.).
+            llm_provider: Provider key, e.g. "openai", "dapr".
+            tools: Tools to expose to the model (AgentTool or already-shaped dicts).
+            response_format:
+                - If structured_mode == "json": a JSON Schema dict or a Pydantic model
+                  (we'll convert) to request raw JSON output.
+                - If structured_mode == "function_call": a Pydantic model describing
+                  the function/tool signature for model-side function calling.
+            structured_mode: "json" for raw JSON structured output,
+                "function_call" for tool/function calling.
 
         Returns:
-            Dict[str, Any]: Prepared request parameters.
+            A params dict ready for the target provider.
         """
+
+        # Tools
         if tools:
             logger.info("Tools are available in the request.")
-            # Convert AgentTool objects to dict format for the provider
-            tool_dicts = []
-            for tool in tools:
-                if isinstance(tool, AgentTool):
-                    tool_dicts.append(
-                        ToolHelper.format_tool(tool, tool_format=llm_provider)
-                    )
-                else:
-                    tool_dicts.append(
-                        ToolHelper.format_tool(tool, tool_format=llm_provider)
-                    )
-            params["tools"] = tool_dicts
+            params["tools"] = [
+                ToolHelper.format_tool(t, tool_format=llm_provider) for t in tools
+            ]
 
+        # Structured output
         if response_format:
-            logger.info(f"Structured Mode Activated! Mode={structured_mode}.")
-            # Add system message for JSON formatting
-            # This is necessary for the response formatting of the data to work correctly when a user has a function call response format.
-            inputs = params.get("inputs", [])
-            inputs.insert(
-                0,
-                {
-                    "role": "system",
-                    "content": "You must format your response as a valid JSON object matching the provided schema. Do not include any explanatory text or markdown formatting.",
-                },
-            )
-            params["inputs"] = inputs
+            logger.info(f"Structured Mode Activated! mode={structured_mode}")
+
+            # If we're on Dapr, we cannot rely on OpenAI-style `response_format`.
+            # Add a small system nudge to enforce JSON-only output so we can parse reliably.
+            if llm_provider == "dapr":
+                params = StructureHandler.ensure_json_only_system_prompt(params)
 
+            # Generate provider-specific request params
            params = StructureHandler.generate_request(
                response_format=response_format,
                llm_provider=llm_provider,
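
For context on the tools change: both arms of the removed if/else called ToolHelper.format_tool identically, so the new list comprehension is a behavior-preserving simplification. A tiny stand-alone sketch (fmt is a hypothetical stand-in for ToolHelper.format_tool):

def fmt(tool: str) -> dict:
    # Hypothetical stand-in for ToolHelper.format_tool(tool, tool_format=llm_provider)
    return {"type": "function", "name": tool}

tools = ["search", "summarize"]

old_style = []
for tool in tools:
    # The removed code branched on isinstance(tool, AgentTool),
    # but both branches appended the same formatted dict.
    old_style.append(fmt(tool))

new_style = [fmt(t) for t in tools]
assert old_style == new_style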

dapr_agents/llm/utils/structure.py

Lines changed: 27 additions & 0 deletions
@@ -627,3 +627,30 @@ def validate_against_signature(result: Any, expected_type: Any) -> Any:
             return adapter.validate_python(result)
         except ValidationError as e:
             raise TypeError(f"Validation failed for type {expected_type}: {e}")
+
+    @staticmethod
+    def ensure_json_only_system_prompt(params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Dapr's chat client (today) does NOT forward OpenAI-style `response_format`
+        (e.g., {"type":"json_schema", ...}). That means the model won't be hard-constrained
+        to your schema. As a fallback, we prepend a system message that instructs the
+        model to return strict JSON so downstream parsing doesn't break.
+
+        Note:
+            - Dapr uses "inputs" (not "messages") for the message array.
+            - If "inputs" isn't present (future providers), we fall back to "messages".
+        """
+        collection_key = "inputs" if "inputs" in params else "messages"
+        msgs = list(params.get(collection_key, []))
+        msgs.insert(
+            0,
+            {
+                "role": "system",
+                "content": (
+                    "Return ONLY a valid JSON object that matches the provided schema. "
+                    "No markdown, no code fences, no explanations—JSON object only."
+                ),
+            },
+        )
+        params[collection_key] = msgs
+        return params
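
A quick sanity check of the new helper, imported the same way request.py does (from dapr_agents.llm.utils.structure import StructureHandler); the example messages are illustrative:

from dapr_agents.llm.utils.structure import StructureHandler

params = {"inputs": [{"role": "user", "content": "Summarize the report as JSON."}]}
params = StructureHandler.ensure_json_only_system_prompt(params)

# The JSON-only instruction is prepended; the original user message follows it.
assert params["inputs"][0]["role"] == "system"
assert params["inputs"][1]["role"] == "user"

# Without "inputs", the helper falls back to "messages".
fallback = StructureHandler.ensure_json_only_system_prompt({"messages": []})
assert fallback["messages"][0]["role"] == "system"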
Lines changed: 9 additions & 1 deletion
@@ -1,5 +1,13 @@
 from .core import task, workflow
 from .fastapi import route
 from .messaging import message_router
+from .activities import llm_activity, agent_activity
 
-__all__ = ["workflow", "task", "route", "message_router"]
+__all__ = [
+    "workflow",
+    "task",
+    "route",
+    "message_router",
+    "llm_activity",
+    "agent_activity",
+]
Lines changed: 193 additions & 0 deletions
@@ -0,0 +1,193 @@
from __future__ import annotations

import asyncio
import functools
import inspect
import logging
from typing import Any, Callable, Literal, Optional

from dapr.ext.workflow import WorkflowActivityContext  # type: ignore

from dapr_agents.agents.base import AgentBase
from dapr_agents.llm.chat import ChatClientBase
from dapr_agents.workflow.utils.activities import (
    build_llm_params,
    convert_result,
    extract_ctx_and_payload,
    format_agent_input,
    format_prompt,
    normalize_input,
    strip_context_parameter,
    validate_result,
)

logger = logging.getLogger(__name__)


def llm_activity(
    *,
    prompt: str,
    llm: ChatClientBase,
    structured_mode: Literal["json", "function_call"] = "json",
    **task_kwargs: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Delegate an activity's implementation to an LLM.

    The decorated function's body is not executed directly. Instead:
      1) Build a prompt from the activity's signature + `prompt`
      2) Call the provided LLM client
      3) Validate the result against the activity's return annotation

    Args:
        prompt: Prompt template (e.g., "Summarize {text} in 3 bullets.")
        llm: Chat client capable of `generate(**params)`.
        structured_mode: Provider structured output mode ("json" or "function_call").
        **task_kwargs: Reserved for future routing/provider knobs.

    Returns:
        A wrapper suitable to register as a Dapr activity.

    Raises:
        ValueError: If `prompt` is empty or `llm` is missing.
    """
    if not prompt:
        raise ValueError("@llm_activity requires a prompt template.")
    if llm is None:
        raise ValueError("@llm_activity requires an explicit `llm` client instance.")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        if not callable(func):
            raise ValueError("@llm_activity must decorate a callable activity.")

        original_sig = inspect.signature(func)
        activity_sig = strip_context_parameter(original_sig)
        effective_structured_mode = task_kwargs.get("structured_mode", structured_mode)

        async def _execute(ctx: WorkflowActivityContext, payload: Any = None) -> Any:
            """Run the LLM pipeline inside the worker."""
            normalized = (
                normalize_input(activity_sig, payload) if payload is not None else {}
            )

            formatted_prompt = format_prompt(activity_sig, prompt, normalized)
            params = build_llm_params(
                activity_sig, formatted_prompt, effective_structured_mode
            )

            raw = llm.generate(**params)
            if inspect.isawaitable(raw):
                raw = await raw

            converted = convert_result(raw)
            validated = await validate_result(converted, activity_sig)
            return validated

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Sync activity wrapper: execute async pipeline to completion."""
            ctx, payload = extract_ctx_and_payload(args, dict(kwargs))
            result = _execute(ctx, payload)  # coroutine

            # If we're in a thread with an active loop, run thread-safely
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop and loop.is_running():
                fut = asyncio.run_coroutine_threadsafe(result, loop)
                return fut.result()

            # Otherwise create and run a fresh loop
            return asyncio.run(result)

        # Useful metadata for debugging/inspection
        wrapper._is_llm_activity = True  # noqa: SLF001
        wrapper._llm_activity_config = {  # noqa: SLF001
            "prompt": prompt,
            "structured_mode": effective_structured_mode,
            "task_kwargs": task_kwargs,
        }
        wrapper._original_activity = func  # noqa: SLF001
        return wrapper

    return decorator


def agent_activity(
    *,
    agent: AgentBase,
    prompt: Optional[str] = None,
    **task_kwargs: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Route an activity through an `AgentBase`.

    The agent receives either a formatted `prompt` or a natural-language
    rendering of the payload. The result is validated against the activity's return
    annotation.

    Args:
        agent: Agent to run the activity through.
        prompt: Optional prompt template for the agent.
        **task_kwargs: Reserved for future routing/provider knobs.

    Returns:
        A wrapper suitable to register as a Dapr activity.

    Raises:
        ValueError: If `agent` is missing.
    """
    if agent is None:
        raise ValueError("@agent_activity requires an AgentBase instance.")

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        if not callable(func):
            raise ValueError("@agent_activity must decorate a callable activity.")

        original_sig = inspect.signature(func)
        activity_sig = strip_context_parameter(original_sig)
        prompt_template = prompt or ""

        async def _execute(ctx: WorkflowActivityContext, payload: Any = None) -> Any:
            normalized = (
                normalize_input(activity_sig, payload) if payload is not None else {}
            )

            if prompt_template:
                formatted_prompt = format_prompt(
                    activity_sig, prompt_template, normalized
                )
            else:
                formatted_prompt = format_agent_input(payload, normalized)

            raw = await agent.run(formatted_prompt)
            converted = convert_result(raw)
            validated = await validate_result(converted, activity_sig)
            return validated

        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            """Sync activity wrapper: execute async pipeline to completion."""
            ctx, payload = extract_ctx_and_payload(args, dict(kwargs))
            result = _execute(ctx, payload)  # coroutine

            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                loop = None

            if loop and loop.is_running():
                fut = asyncio.run_coroutine_threadsafe(result, loop)
                return fut.result()

            return asyncio.run(result)

        wrapper._is_agent_activity = True  # noqa: SLF001
        wrapper._agent_activity_config = {  # noqa: SLF001
            "prompt": prompt,
            "task_kwargs": task_kwargs,
        }
        wrapper._original_activity = func  # noqa: SLF001
        return wrapper

    return decorator
