This repository was archived by the owner on Mar 26, 2025. It is now read-only.

Commit 19eb988

Authored by Sean Reilly (reillyse)

add a run_workflows method to run many workflows in bulk (#222)

* add a run_workflows method to run many workflows in bulk
* lint
* use the dedupe exceptions when we have a dedupe
* add spawn_workflows and run_workflows with an example
* think this is how the testing works
* Add a test although it hangs on the second one - I think because of event loop stuff
* is this how I bump version
* maybe

Co-authored-by: Sean Reilly <[email protected]>

1 parent 8dd2efe · commit 19eb988
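In short, the commit adds a bulk path next to the existing run_workflow: build a list of WorkflowRunDict entries and pass them to admin.run_workflows, which triggers them all in a single gRPC call and returns one WorkflowRunRef per run. A minimal sketch of the new surface, using only names that appear in the diffs below; client configuration is assumed to come from the environment:

from hatchet_sdk import new_client

hatchet = new_client()  # reads connection settings from the environment

# One WorkflowRunDict per run: workflow_name, input, and per-run options.
refs = hatchet.admin.run_workflows(
    [
        {"workflow_name": "BulkParent", "input": {"n": i}, "options": {}}
        for i in range(3)
    ]
)
# Each ref is a WorkflowRunRef; await ref.result() (in async code) for that run's output.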


46 files changed: +2,347 −146 lines
Lines changed: 52 additions & 0 deletions

@@ -0,0 +1,52 @@
import asyncio
from typing import List

from dotenv import load_dotenv

from hatchet_sdk import new_client
from hatchet_sdk.clients.admin import WorkflowRunDict


async def main():
    load_dotenv()
    hatchet = new_client()

    workflowRuns: List[WorkflowRunDict] = []

    # We are going to run the BulkParent workflow 20 times, which will
    # trigger the Child workflows n times for each n in range(20).
    for i in range(20):
        workflowRuns.append(
            {
                "workflow_name": "BulkParent",
                "input": {"n": i},
                "options": {
                    "additional_metadata": {
                        "bulk-trigger": i,
                        f"hello-{i}": f"earth-{i}",
                    },
                },
            }
        )

    workflowRunRefs = hatchet.admin.run_workflows(
        workflowRuns,
    )

    results = await asyncio.gather(
        *[workflowRunRef.result() for workflowRunRef in workflowRunRefs],
        return_exceptions=True,
    )

    for result in results:
        if isinstance(result, Exception):
            print(f"An error occurred: {result}")  # Handle the exception here
        else:
            print(result)


if __name__ == "__main__":
    asyncio.run(main())

examples/bulk_fanout/stream.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
import asyncio
import random

from dotenv import load_dotenv

from hatchet_sdk.v2.hatchet import Hatchet


async def main():
    load_dotenv()
    hatchet = Hatchet()

    # Generate a random stream key to use to track all
    # stream events for this workflow run.
    streamKey = "streamKey"
    streamVal = f"sk-{random.randint(1, 100)}"

    # Specify the stream key as additional metadata
    # when running the workflow.

    # This key gets propagated to all child workflows
    # and can have an arbitrary property name.
    workflowRun = hatchet.admin.run_workflow(
        "Parent",
        {"n": 2},
        options={"additional_metadata": {streamKey: streamVal}},
    )

    # Stream all events for the additional meta key value
    listener = hatchet.listener.stream_by_additional_metadata(streamKey, streamVal)

    async for event in listener:
        print(event.type, event.payload)


if __name__ == "__main__":
    asyncio.run(main())
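The loop above prints every stream event until the listener is exhausted. If you only need output up to completion, you can break out early; a minimal sketch, assuming StepRunEventType exposes a STEP_RUN_EVENT_TYPE_COMPLETED member — check the enum members in your hatchet_sdk version:

from hatchet_sdk.clients.run_event_listener import StepRunEventType


async def stream_until_complete(listener):
    # Print stream events, stopping at the first completion event.
    async for event in listener:
        print(event.type, event.payload)
        if event.type == StepRunEventType.STEP_RUN_EVENT_TYPE_COMPLETED:
            break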
Lines changed: 25 additions & 0 deletions

@@ -0,0 +1,25 @@
import pytest

from hatchet_sdk import Hatchet
from tests.utils import fixture_bg_worker
from tests.utils.hatchet_client import hatchet_client_fixture

hatchet = hatchet_client_fixture()
worker = fixture_bg_worker(["poetry", "run", "bulk_fanout"])


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
async def test_run(hatchet: Hatchet):
    run = hatchet.admin.run_workflow("BulkParent", {"n": 12})
    result = await run.result()
    print(result)
    assert len(result["spawn"]["results"]) == 12


# requires scope module or higher for shared event loop
@pytest.mark.asyncio(scope="session")
async def test_run2(hatchet: Hatchet):
    run = hatchet.admin.run_workflow("BulkParent", {"n": 10})
    result = await run.result()
    assert len(result["spawn"]["results"]) == 10
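The hatchet_client_fixture and fixture_bg_worker helpers are imported from the repo's tests/utils package and are not part of this diff. As a rough, hypothetical sketch of the shapes those factories could take (the real implementations in tests/utils may differ):

import subprocess

import pytest

from hatchet_sdk import Hatchet


def hatchet_client_fixture():
    # Hypothetical: yields a configured client for each test.
    @pytest.fixture
    def hatchet():
        return Hatchet(debug=True)

    return hatchet


def fixture_bg_worker(command):
    # Hypothetical: runs the example worker in the background for the session.
    @pytest.fixture(scope="session", autouse=True)
    def worker():
        proc = subprocess.Popen(command)
        yield proc
        proc.terminate()
        proc.wait()

    return worker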

examples/bulk_fanout/trigger.py

Lines changed: 26 additions & 0 deletions
@@ -0,0 +1,26 @@
import asyncio

from dotenv import load_dotenv

from hatchet_sdk import new_client


async def main():
    load_dotenv()
    hatchet = new_client()

    # Push an event that triggers the parent workflow, attaching
    # additional metadata to the run.
    event = hatchet.event.push(
        "parent:create", {"n": 999}, {"additional_metadata": {"no-dedupe": "world"}}
    )


if __name__ == "__main__":
    asyncio.run(main())

examples/bulk_fanout/worker.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import asyncio
from typing import List

from dotenv import load_dotenv

from hatchet_sdk import Context, Hatchet
from hatchet_sdk.clients.admin import ChildWorkflowRunDict

load_dotenv()

hatchet = Hatchet(debug=True)


@hatchet.workflow(on_events=["parent:create"])
class BulkParent:
    @hatchet.step(timeout="5m")
    async def spawn(self, context: Context):
        print("spawning child")

        context.put_stream("spawning...")
        results = []

        n = context.workflow_input().get("n", 100)

        child_workflow_runs: List[ChildWorkflowRunDict] = []

        for i in range(n):
            child_workflow_runs.append(
                {
                    "workflow_name": "BulkChild",
                    "input": {"a": str(i)},
                    "key": f"child{i}",
                    "options": {"additional_metadata": {"hello": "earth"}},
                }
            )

        if len(child_workflow_runs) == 0:
            return

        spawn_results = await context.aio.spawn_workflows(child_workflow_runs)

        results = await asyncio.gather(
            *[workflowRunRef.result() for workflowRunRef in spawn_results],
            return_exceptions=True,
        )

        print("finished spawning children")

        for result in results:
            if isinstance(result, Exception):
                print(f"An error occurred: {result}")
            else:
                print(result)

        return {"results": results}


@hatchet.workflow(on_events=["child:create"])
class BulkChild:
    @hatchet.step()
    def process(self, context: Context):
        a = context.workflow_input()["a"]
        print(f"child process {a}")
        context.put_stream("child 1...")
        return {"status": "success " + a}

    @hatchet.step()
    def process2(self, context: Context):
        print("child process2")
        context.put_stream("child 2...")
        return {"status2": "success"}


def main():
    worker = hatchet.worker("fanout-worker", max_runs=40)
    worker.register_workflow(BulkParent())
    worker.register_workflow(BulkChild())
    worker.start()


if __name__ == "__main__":
    main()

hatchet_sdk/clients/admin.py

Lines changed: 130 additions & 0 deletions
@@ -5,11 +5,14 @@
 import grpc
 from google.protobuf import timestamp_pb2

+from hatchet_sdk.clients.rest.models.workflow_run import WorkflowRun
 from hatchet_sdk.clients.rest.tenacity_utils import tenacity_retry
 from hatchet_sdk.clients.run_event_listener import new_listener
 from hatchet_sdk.clients.workflow_listener import PooledWorkflowRunListener
 from hatchet_sdk.connection import new_conn
 from hatchet_sdk.contracts.workflows_pb2 import (
+    BulkTriggerWorkflowRequest,
+    BulkTriggerWorkflowResponse,
     CreateWorkflowVersionOpts,
     PutRateLimitRequest,
     PutWorkflowRequest,

@@ -44,6 +47,19 @@ class ChildTriggerWorkflowOptions(TypedDict):
     sticky: bool | None = None


+class WorkflowRunDict(TypedDict):
+    workflow_name: str
+    input: Any
+    options: Optional[dict]
+
+
+class ChildWorkflowRunDict(TypedDict):
+    workflow_name: str
+    input: Any
+    options: ChildTriggerWorkflowOptions
+    key: str
+
+
 class TriggerWorkflowOptions(ScheduleTriggerWorkflowOptions, TypedDict):
     additional_metadata: Dict[str, str] | None = None
     desired_worker_id: str | None = None

@@ -203,6 +219,65 @@ async def run_workflow(

             raise ValueError(f"gRPC error: {e}")

+    @tenacity_retry
+    async def run_workflows(
+        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+    ) -> List[WorkflowRunRef]:
+
+        if len(workflows) == 0:
+            raise ValueError("No workflows to run")
+        try:
+            if not self.pooled_workflow_listener:
+                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+            namespace = self.namespace
+
+            if (
+                options is not None
+                and "namespace" in options
+                and options["namespace"] is not None
+            ):
+                namespace = options["namespace"]
+                del options["namespace"]
+
+            workflow_run_requests: List[TriggerWorkflowRequest] = []
+
+            for workflow in workflows:
+                workflow_name = workflow["workflow_name"]
+                input_data = workflow["input"]
+                options = workflow["options"]
+
+                if namespace != "" and not workflow_name.startswith(self.namespace):
+                    workflow_name = f"{namespace}{workflow_name}"
+
+                # Prepare and trigger workflow for each workflow name and input
+                request = self._prepare_workflow_request(
+                    workflow_name, input_data, options
+                )
+                workflow_run_requests.append(request)
+
+            request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+            resp: BulkTriggerWorkflowResponse = (
+                await self.aio_client.BulkTriggerWorkflow(
+                    request,
+                    metadata=get_metadata(self.token),
+                )
+            )
+
+            return [
+                WorkflowRunRef(
+                    workflow_run_id=workflow_run_id,
+                    workflow_listener=self.pooled_workflow_listener,
+                    workflow_run_event_listener=self.listener_client,
+                )
+                for workflow_run_id in resp.workflow_run_ids
+            ]
+
+        except grpc.RpcError as e:
+            raise ValueError(f"gRPC error: {e}")
+
     @tenacity_retry
     async def put_workflow(
         self,

@@ -398,6 +473,61 @@ def run_workflow(

             raise ValueError(f"gRPC error: {e}")

+    @tenacity_retry
+    def run_workflows(
+        self, workflows: List[WorkflowRunDict], options: TriggerWorkflowOptions = None
+    ) -> list[WorkflowRunRef]:
+
+        workflow_run_requests: List[TriggerWorkflowRequest] = []
+        try:
+            if not self.pooled_workflow_listener:
+                self.pooled_workflow_listener = PooledWorkflowRunListener(self.config)
+
+            for workflow in workflows:
+                workflow_name = workflow["workflow_name"]
+                input_data = workflow["input"]
+                options = workflow["options"]
+
+                namespace = self.namespace
+
+                if (
+                    options is not None
+                    and "namespace" in options
+                    and options["namespace"] is not None
+                ):
+                    namespace = options["namespace"]
+                    del options["namespace"]
+
+                if namespace != "" and not workflow_name.startswith(self.namespace):
+                    workflow_name = f"{namespace}{workflow_name}"
+
+                # Prepare and trigger workflow for each workflow name and input
+                request = self._prepare_workflow_request(
+                    workflow_name, input_data, options
+                )
+
+                workflow_run_requests.append(request)
+
+            request = BulkTriggerWorkflowRequest(workflows=workflow_run_requests)
+
+            resp: BulkTriggerWorkflowResponse = self.client.BulkTriggerWorkflow(
+                request,
+                metadata=get_metadata(self.token),
+            )
+
+        except grpc.RpcError as e:
+            raise ValueError(f"gRPC error: {e}")
+
+        return [
+            WorkflowRunRef(
+                workflow_run_id=workflow_run_id,
+                workflow_listener=self.pooled_workflow_listener,
+                workflow_run_event_listener=self.listener_client,
+            )
+            for workflow_run_id in resp.workflow_run_ids
+        ]
+
     def run(
         self,
         function: Union[str, Callable[[Any], T]],