Skip to content

Commit 815f04a

Browse files
committed
[Temporal - Documentation] Fixed ruff issues
1 parent b5f3586 commit 815f04a

File tree

1 file changed

+103
-91
lines changed

1 file changed

+103
-91
lines changed

docs/durable_execution/temporal.md

Lines changed: 103 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -205,10 +205,13 @@ Assuming your project has the following structure:
205205
206206
```
207207

208-
```py {title="utils.py" test="skip" noqa="F841"}
209-
import yaml
208+
```py {title="utils.py" test="skip"}
209+
import os
210210
from copy import copy
211211

212+
import yaml
213+
214+
212215
def recursively_modify_api_key(conf):
213216
"""
214217
Recursively replace API key placeholders with environment variable values.
@@ -243,6 +246,7 @@ def recursively_modify_api_key(conf):
243246
inner(copy_conf)
244247
return copy_conf
245248

249+
246250
def read_config_yml(path):
247251
"""
248252
Read and process a YAML configuration file.
@@ -256,21 +260,23 @@ def read_config_yml(path):
256260
Returns:
257261
dict: The parsed and processed YAML content as a Python dictionary.
258262
"""
259-
with open(path, "r") as f:
263+
with open(path) as f:
260264
configs = yaml.safe_load(f)
261265
recursively_modify_api_key(configs)
262266
return configs
263267
```
264268

265-
```py {title="datamodels.py" test="skip" noqa="F841"}
269+
```py {title="datamodels.py" test="skip"}
266270
from enum import Enum
267-
from typing import Any, Dict, Deque, AsyncIterable, Optional
271+
268272
from pydantic import BaseModel
269273

274+
270275
class AgentDependencies(BaseModel):
271276
workflow_id: str
272277
run_id: str
273278

279+
274280
class EventKind(str, Enum):
275281
CONTINUE_CHAT = 'continue_chat'
276282
EVENT = 'event'
@@ -282,26 +288,25 @@ class EventStream(BaseModel):
282288
content: str
283289
```
284290

285-
```py {title="agents.py" test="skip" noqa="F841"}
286-
from temporalio import workflow
291+
292+
```py {title="agents.py" test="skip"}
293+
from datetime import timedelta
294+
295+
from mcp_run_python import code_sandbox
296+
from pydantic_ai import Agent, FilteredToolset, ModelSettings, RunContext
297+
from pydantic_ai.durable_exec.temporal import TemporalAgent
298+
from pydantic_ai.mcp import MCPServerStdio
299+
from pydantic_ai.models.anthropic import AnthropicModel
300+
from pydantic_ai.providers.anthropic import AnthropicProvider
287301
from temporalio.common import RetryPolicy
288302
from temporalio.workflow import ActivityConfig
289-
with workflow.unsafe.imports_passed_through():
290-
from pydantic_ai import Agent, FilteredToolset, RunContext, ModelSettings
291-
from pydantic_ai.models.anthropic import AnthropicModel
292-
from pydantic_ai.providers.anthropic import AnthropicProvider
293-
from pydantic_ai.mcp import MCPServerStdio
294-
from pydantic_ai.durable_exec.temporal import TemporalAgent
295-
from datamodels import AgentDependencies
296-
from mcp_run_python import code_sandbox
297-
from typing import Dict
298-
from datetime import timedelta
299-
300-
301-
async def get_mcp_toolsets() -> Dict[str, FilteredToolset]:
303+
304+
from .datamodels import AgentDependencies
305+
306+
async def get_mcp_toolsets() -> dict[str, FilteredToolset]:
302307
yf_server = MCPServerStdio(
303-
command="uvx",
304-
args=["mcp-yahoo-finance"],
308+
command='uvx',
309+
args=['mcp-yahoo-finance'],
305310
timeout=240,
306311
read_timeout=240,
307312
id='yahoo'
@@ -310,76 +315,87 @@ async def get_mcp_toolsets() -> Dict[str, FilteredToolset]:
310315
'yahoo': yf_server.filtered(lambda ctx, tool_def: True)
311316
}
312317

318+
313319
async def get_claude_model(parallel_tool_calls: bool = True, **env_vars):
314320
model_name = 'claude-sonnet-4-5-20250929'
315321
api_key = env_vars.get('anthropic_api_key')
316322
model = AnthropicModel(model_name=model_name,
317323
provider=AnthropicProvider(api_key=api_key),
318324
settings=ModelSettings(**{
319-
"temperature": 0.5,
320-
"n": 1,
321-
"max_completion_tokens": 64000,
322-
"max_tokens": 64000,
323-
"parallel_tool_calls": parallel_tool_calls,
325+
'temperature': 0.5,
326+
'n': 1,
327+
'max_completion_tokens': 64000,
328+
'max_tokens': 64000,
329+
'parallel_tool_calls': parallel_tool_calls,
324330
}))
325331

326332
return model
327333

334+
328335
async def build_agent(stream_handler=None, **env_vars):
329-
330336
system_prompt = """
331337
You are an expert travel agent that knows perfectly how to search for hotels on the web.
332338
You also have a Data Analyst background, mastering well how to use pandas for tabular operations.
333339
"""
334-
agent_name = "YahooFinanceSearchAgent"
335-
340+
agent_name = 'YahooFinanceSearchAgent'
341+
336342
toolsets = await get_mcp_toolsets()
337343
agent = Agent(name=agent_name,
338-
model=await get_claude_model(**env_vars), # Here you place your Model instance
344+
model=await get_claude_model(**env_vars), # Here you place your Model instance
339345
toolsets=[*toolsets.values()],
340346
system_prompt=system_prompt,
341347
event_stream_handler=stream_handler,
342348
deps_type=AgentDependencies,
343349
)
344-
350+
345351
@agent.tool(name='run_python_code')
346352
async def run_python_code(ctx: RunContext[None], code: str) -> str:
347353
async with code_sandbox(dependencies=['pandas', 'numpy']) as sandbox:
348354
result = await sandbox.eval(code)
349355
return result
350-
351-
356+
352357
temporal_agent = TemporalAgent(wrapped=agent,
353-
model_activity_config=ActivityConfig(
354-
start_to_close_timeout=timedelta(minutes=5),
355-
retry_policy=RetryPolicy(maximum_attempts=50)
356-
),
357-
toolset_activity_config={
358-
toolset_id: ActivityConfig(
359-
start_to_close_timeout=timedelta(minutes=3),
360-
retry_policy=RetryPolicy(maximum_attempts=3,
361-
non_retryable_error_types=['ToolRetryError']
362-
)
363-
) for toolset_id in toolsets.keys()})
358+
model_activity_config=ActivityConfig(
359+
start_to_close_timeout=timedelta(minutes=5),
360+
retry_policy=RetryPolicy(maximum_attempts=50)
361+
),
362+
toolset_activity_config={
363+
toolset_id: ActivityConfig(
364+
start_to_close_timeout=timedelta(minutes=3),
365+
retry_policy=RetryPolicy(maximum_attempts=3,
366+
non_retryable_error_types=['ToolRetryError']
367+
)
368+
) for toolset_id in toolsets.keys()})
364369
return temporal_agent
365370
```
366371

367-
```py {title="streaming_handler.py" test="skip" noqa="F841"}
372+
```py {title="streaming_handler.py" test="skip"}
373+
from collections.abc import AsyncIterable
374+
375+
from .datamodels import AgentDependencies, EventKind, EventStream
368376
from temporalio import activity
369-
from typing import Any, Dict, Deque, AsyncIterable, Optional
370-
from pydantic_ai import AgentStreamEvent, FunctionToolCallEvent, \
371-
PartStartEvent, FunctionToolResultEvent, TextPart, ToolCallPart, PartDeltaEvent, UsageLimits, TextPartDelta, \
372-
ThinkingPartDelta
373-
from datamodels import EventStream, EventKind, AgentDependencies
377+
378+
from pydantic_ai import (
379+
AgentStreamEvent,
380+
FunctionToolCallEvent,
381+
PartStartEvent,
382+
FunctionToolResultEvent,
383+
TextPart,
384+
ToolCallPart,
385+
PartDeltaEvent,
386+
TextPartDelta,
387+
ThinkingPartDelta,
388+
)
389+
374390

375391
async def streaming_handler(ctx,
376392
event_stream_events: AsyncIterable[AgentStreamEvent]):
377393
"""
378394
This function is used by the agent to stream-out the actions that are being performed (tool calls, llm call, streaming results, etc etc.
379395
Feel free to change it as you like or need - skipping events or enriching the content
380396
"""
381-
382-
output = ""
397+
398+
output = ''
383399
output_tool_delta = dict(
384400
tool_call_id='',
385401
tool_name_delta='',
@@ -389,34 +405,34 @@ async def streaming_handler(ctx,
389405
async for event in event_stream_events:
390406
if isinstance(event, PartStartEvent):
391407
if isinstance(event.part, TextPart):
392-
output += f"{event.part.content}"
408+
output += f'{event.part.content}'
393409
elif isinstance(event.part, ToolCallPart):
394-
output += f"\nTool Call Id: {event.part.tool_call_id}"
395-
output += f"\nTool Name: {event.part.tool_name}"
396-
output += f"\nTool Args: {event.part.args}"
410+
output += f'\nTool Call Id: {event.part.tool_call_id}'
411+
output += f'\nTool Name: {event.part.tool_name}'
412+
output += f'\nTool Args: {event.part.args}'
397413
else:
398414
pass
399415
elif isinstance(event, FunctionToolCallEvent):
400-
output += f"\nTool Call Id: {event.part.tool_call_id}"
401-
output += f"\nTool Name: {event.part.tool_name}"
402-
output += f"\nTool Args: {event.part.args}"
416+
output += f'\nTool Call Id: {event.part.tool_call_id}'
417+
output += f'\nTool Name: {event.part.tool_name}'
418+
output += f'\nTool Args: {event.part.args}'
403419
elif isinstance(event, FunctionToolResultEvent):
404-
output += f"\nTool Call Id: {event.result.tool_call_id}"
405-
output += f"\nTool Name: {event.result.tool_name}"
406-
output += f"\nContent: {event.result.content}"
420+
output += f'\nTool Call Id: {event.result.tool_call_id}'
421+
output += f'\nTool Name: {event.result.tool_name}'
422+
output += f'\nContent: {event.result.content}'
407423
elif isinstance(event, PartDeltaEvent):
408424
if isinstance(event.delta, TextPartDelta) or isinstance(event.delta, ThinkingPartDelta):
409-
output += f"{event.delta.content_delta}"
425+
output += f'{event.delta.content_delta}'
410426
else:
411427
if len(output_tool_delta['tool_call_id']) == 0:
412428
output_tool_delta['tool_call_id'] += event.delta.tool_call_id or ''
413429
output_tool_delta['tool_name_delta'] += event.delta.tool_name_delta or ''
414430
output_tool_delta['args_delta'] += event.delta.args_delta or ''
415431

416432
if len(output_tool_delta['tool_call_id']):
417-
output += f"\nTool Call Id: {output_tool_delta['tool_call_id']}"
418-
output += f"\nTool Name: {output_tool_delta['tool_name_delta']}"
419-
output += f"\nTool Args: {output_tool_delta['args_delta']}"
433+
output += f'\nTool Call Id: {output_tool_delta["tool_call_id"]}'
434+
output += f'\nTool Name: {output_tool_delta["tool_name_delta"]}'
435+
output += f'\nTool Args: {output_tool_delta["args_delta"]}'
420436

421437
events = []
422438

@@ -434,36 +450,31 @@ async def streaming_handler(ctx,
434450
await workflow_handle.signal('append_event', arg=event)
435451
```
436452

437-
438-
```py {title="workflow.py" test="skip" noqa="F841"}
453+
```py {title="workflow.py" test="skip"}
439454

440455
import asyncio
441456
from collections import deque
442457
from datetime import timedelta
443-
from typing import Any, Dict, Deque, Optional
458+
from typing import Any
444459

445-
from temporalio import workflow, activity
446-
447-
with workflow.unsafe.imports_passed_through():
448-
from datamodels import EventStream, EventKind, AgentDependencies
449-
from agents import YahooFinanceSearchAgent
450-
from pydanticai import UsageLimits
451-
from agents import streaming_handler, build_agent
452-
from utils import read_config_yml
460+
from pydanticai import UsageLimits
461+
from temporalio import activity, workflow
453462

463+
from .agents import build_agent, streaming_handler
464+
from .datamodels import AgentDependencies, EventKind, EventStream
454465

455466
@workflow.defn
456467
class YahooFinanceSearchWorkflow:
457468
def __init__(self):
458-
self.events: Deque[EventStream] = deque()
469+
self.events: deque[EventStream] = deque()
459470

460471
@workflow.run
461472
async def run(self, user_prompt: str):
462473

463474
wf_vars = await workflow.execute_activity(
464475
activity='retrieve_env_vars',
465476
start_to_close_timeout=timedelta(seconds=10),
466-
result_type=Dict[str, Any],
477+
result_type=dict[str, Any],
467478
)
468479
deps = AgentDependencies(workflow_id=workflow.info().workflow_id, run_id=workflow.info().run_id)
469480

@@ -472,12 +483,12 @@ class YahooFinanceSearchWorkflow:
472483
usage_limits=UsageLimits(request_limit=50),
473484
deps=deps
474485
)
475-
486+
476487
await self.append_event(event_stream=EventStream(kind=EventKind.RESULT,
477-
content=result.output))
488+
content=result.output))
478489

479490
await self.append_event(event_stream=EventStream(kind=EventKind.CONTINUE_CHAT,
480-
content=""))
491+
content=''))
481492

482493
try:
483494
await workflow.wait_condition(
@@ -492,16 +503,17 @@ class YahooFinanceSearchWorkflow:
492503
@staticmethod
493504
@activity.defn(name='retrieve_env_vars')
494505
async def retrieve_env_vars():
495-
with workflow.unsafe.imports_passed_through():
496-
import os
497-
config_path = os.getenv('APP_CONFIG_PATH', './app_conf.yml')
498-
configs = read_config_yml(config_path)
499-
return {
500-
'anthropic_api_key': configs['llm']['anthropic_api_key']
501-
}
506+
import os
507+
from .utils import read_config_yml
508+
509+
config_path = os.getenv('APP_CONFIG_PATH', './app_conf.yml')
510+
configs = read_config_yml(config_path)
511+
return {
512+
'anthropic_api_key': configs['llm']['anthropic_api_key']
513+
}
502514

503515
@workflow.query
504-
def event_stream(self) -> Optional[EventStream]:
516+
def event_stream(self) -> EventStream | None:
505517
if self.events:
506518
return self.events.popleft()
507519
return None

0 commit comments

Comments
 (0)