-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathworker.py
More file actions
66 lines (53 loc) · 1.77 KB
/
worker.py
File metadata and controls
66 lines (53 loc) · 1.77 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import asyncio
import json
from concurrent.futures import ThreadPoolExecutor
from datetime import timedelta

from temporalio import activity
from temporalio import workflow
from temporalio.client import Client
from temporalio.worker import Worker

import model
from gen.proto.labels.v1.labels_pb2 import GetLabelsRequest, GetLabelsResponse
# Process-wide cache for the labelling model. It stays None until the first
# label request loads it, and is then reused for every later request.
loaded_model = None
# Names used when registering with the Temporal server: the workflow type,
# the activity type, and the task queue this worker polls.
workflow_name: str = "get-labels-workflow"
activity_name: str = "get-labels-activity"
task_queue: str = "labels-tasks"
# Dedicated single-thread executor for all model work (config read, model load,
# inference). Keeping this blocking work off the asyncio event loop lets the
# worker keep servicing other tasks while a prediction runs. Using exactly one
# thread preserves the original behaviour of handling one model call at a time.
# NOTE(review): the thread-safety of model.run is unknown; only raise
# max_workers if it is confirmed safe.
_model_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="labels-model")


def _run_model(issue: model.Issue) -> list[str]:
    """Load the model on first use, then return its label suggestions for *issue*.

    Always executed on ``_model_executor``'s single thread, so the lazy
    initialisation of the module-level ``loaded_model`` cannot race.
    """
    global loaded_model
    if loaded_model is None:
        # Relative path: resolved against the process's current working directory.
        with open("data.json", encoding="utf-8") as file:
            args = json.load(file)
        loaded_model = model.load_model(
            mlflow_server_uri=args["mlflow_server"],
            model_run_id=args["model_run_id"],
            model_name=args["model_name"],
            model_version=args["model_version"],
        )
    return loaded_model.run(issue)


@activity.defn(name=activity_name)
async def GetLabelsActivity(issue: model.Issue) -> list[str]:
    """Temporal activity: suggest labels for *issue* using the cached model.

    The first call reads ``data.json`` and loads the model via
    ``model.load_model``; later calls reuse it. All of that blocking work runs
    on a dedicated worker thread and is awaited from here.

    If the activity is cancelled, a model call that has already started on the
    executor thread still runs to completion; only the await is abandoned.

    Raises:
        Whatever reading ``data.json``, ``model.load_model`` or the model's
        ``run`` raises (e.g. ``FileNotFoundError``, ``KeyError`` for a missing
        config key). It propagates to Temporal as an activity failure.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(_model_executor, _run_model, issue)
@workflow.defn(name=workflow_name, sandboxed=False)
class GetLabelsWorkflow:
    """Workflow that delegates label suggestion for one issue to the labels activity.

    Registered under ``workflow_name`` with the Temporal workflow sandbox
    disabled (``sandboxed=False``).
    NOTE(review): presumably unsandboxed because of the ``model`` import —
    confirm.
    """

    @workflow.run
    async def run(self, req: GetLabelsRequest) -> GetLabelsResponse:
        """Build a ``model.Issue`` from *req* and return the activity's labels.

        The single activity invocation is bounded by a 10-minute
        start-to-close timeout.
        """
        activity_timeout = timedelta(minutes=10)
        # req.labels is copied into a plain list (presumably a protobuf repeated
        # container) so the activity payload is an ordinary Python list.
        requested_issue = model.Issue(
            title=req.title,
            body=req.body,
            labels=list(req.labels),
            creator=req.creator,
        )
        labels = await workflow.execute_activity(
            GetLabelsActivity,
            requested_issue,
            start_to_close_timeout=activity_timeout,
        )
        return GetLabelsResponse(labels=labels)
async def main(target_host: str = "localhost:7233") -> None:
    """Connect to a Temporal server and run the labels worker until it stops.

    The worker polls ``task_queue`` and serves ``GetLabelsWorkflow`` and
    ``GetLabelsActivity``.

    Args:
        target_host: ``host:port`` of the Temporal frontend service. The
            default is the address that was previously hard-coded, so existing
            ``main()`` callers behave exactly as before.
    """
    client = await Client.connect(target_host=target_host)
    worker = Worker(
        client,
        task_queue=task_queue,
        workflows=[GetLabelsWorkflow],
        activities=[GetLabelsActivity],
    )
    await worker.run()


if __name__ == "__main__":
    asyncio.run(main())