Skip to content

Commit 35d615b

Browse files
authored
feat: add workflows PC-398 (#341)
* chore: add workflow models and serializers * feat(workflows): add list and runs command * feat(workflows): workflows and workflow runs get and list command * feat(workflows): create, run commands * feat(workflows): create run and logs command * fix: update workflowId to id * using --show-runs flag * fix logs command
1 parent 91cb0f7 commit 35d615b

File tree

12 files changed

+703
-0
lines changed

12 files changed

+703
-0
lines changed

gradient/api_sdk/clients/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,4 @@
1515
from .storage_provider_client import StorageProvidersClient
1616
from .sdk_client import SdkClient
1717
from .tensorboards_client import TensorboardClient
18+
from .workflow_client import WorkflowsClient

gradient/api_sdk/clients/sdk_client.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient, \
22
MachinesClient, NotebooksClient, SecretsClient
33
from .job_client import JobsClient
4+
from .workflow_client import WorkflowsClient
45
from .. import logger as sdk_logger
56

67

@@ -19,3 +20,4 @@ def __init__(self, api_key, logger=sdk_logger.MuteLogger()):
1920
self.machines = MachinesClient(api_key=api_key, logger=logger)
2021
self.notebooks = NotebooksClient(api_key=api_key, logger=logger)
2122
self.secrets = SecretsClient(api_key=api_key, logger=logger)
23+
self.workflows = WorkflowsClient(api_key=api_key, logger=logger)
Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from .base_client import BaseClient
2+
from .. import models, repositories
3+
from ...exceptions import ReceivingDataFailedError
4+
5+
6+
class WorkflowsClient(BaseClient):
7+
8+
def create(self, name, project_id):
9+
"""Create workflow with spec
10+
11+
:param str name: workflow name
12+
:param str project_id: project id
13+
14+
:returns: workflow create response
15+
:rtype: list[models.Workflow]
16+
"""
17+
18+
repository = self.build_repository(repositories.CreateWorkflow)
19+
workflow = repository.create(name=name, project_id=project_id)
20+
return workflow
21+
22+
def run_workflow(self, spec, inputs, workflow_id, cluster_id):
23+
"""Create workflow with spec
24+
25+
:param obj spec: workflow spec
26+
:param obj inputs: workflow inputs
27+
:param str workflow_id: workflow id
28+
:param str cluster_id: cluster id
29+
30+
:returns: workflow create response
31+
:rtype: list[models.Workflow]
32+
"""
33+
34+
repository = self.build_repository(repositories.CreateWorkflowRun)
35+
workflow = repository.create(spec=spec, inputs=inputs, id=workflow_id, cluster_id=cluster_id)
36+
return workflow
37+
38+
def list(self, project_id):
39+
"""List workflows by project
40+
41+
:param str project_id: project ID
42+
43+
:returns: list of workflows
44+
:rtype: list[models.Workflow]
45+
"""
46+
47+
repository = self.build_repository(repositories.ListWorkflows)
48+
workflows = repository.list(project_id=project_id)
49+
return workflows
50+
51+
def get(self, workflow_id):
52+
"""Get a Workflow
53+
54+
:param str workflow_id: Workflow ID [required]
55+
56+
:returns: workflow
57+
:rtype: models.Workflow
58+
"""
59+
repository = self.build_repository(repositories.GetWorkflow)
60+
return repository.get(id=workflow_id)
61+
62+
63+
def list_runs(self, workflow_id):
64+
"""List workflows runs by workflow id
65+
66+
:param str workflow_id: workflow ID
67+
68+
:returns: list of workflow runs
69+
"""
70+
71+
repository = self.build_repository(repositories.ListWorkflowRuns)
72+
workflows_runs = repository.get(id=workflow_id)
73+
return workflows_runs
74+
75+
def get_run(self, workflow_id, run):
76+
"""List workflows runs by workflow id
77+
78+
:param str workflow_id: workflow ID
79+
:param str run: run count
80+
81+
:returns: list of workflow runs
82+
"""
83+
84+
repository = self.build_repository(repositories.GetWorkflowRun)
85+
workflows_runs = repository.get(id=workflow_id, run=run)
86+
return workflows_runs
87+
88+
def yield_logs(self, job_id, line=1, limit=10000):
89+
"""Get log generator. Polls the API for new logs
90+
91+
.. code-block:: python
92+
:linenos:
93+
:emphasize-lines: 2
94+
95+
job_logs_generator = job_client.yield_logs(
96+
job_id='Your_job_id_here',
97+
line=100,
98+
limit=100
99+
)
100+
101+
:param str job_id:
102+
:param int line: line number at which logs starts to display on screen
103+
:param int limit: maximum lines displayed on screen, default set to 10 000
104+
105+
:returns: generator yielding LogRow instances
106+
:rtype: Iterator[models.LogRow]
107+
"""
108+
109+
repository = self.build_repository(repositories.ListWorkflowLogs)
110+
logs = repository.yield_logs(id=job_id, line=line, limit=limit)
111+
return logs
112+
113+
def logs(self, job_id, line=1, limit=10000):
114+
"""Get log generator. Polls the API for new logs
115+
116+
.. code-block:: python
117+
:linenos:
118+
:emphasize-lines: 2
119+
120+
job_logs_generator = job_client.yield_logs(
121+
job_id='Your_job_id_here',
122+
line=100,
123+
limit=100
124+
)
125+
126+
:param str job_id:
127+
:param int line: line number at which logs starts to display on screen
128+
:param int limit: maximum lines displayed on screen, default set to 10 000
129+
130+
:returns: generator yielding LogRow instances
131+
:rtype: Iterator[models.LogRow]
132+
"""
133+
134+
repository = self.build_repository(repositories.ListWorkflowLogs)
135+
logs = repository.list(id=job_id, line=line, limit=limit)
136+
return logs

gradient/api_sdk/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,3 +20,4 @@
2020
from .tag import Tag
2121
from .tensorboard import Instance, Tensorboard
2222
from .vm_type import VmType, VmTypeGpuModel
23+
from .workflows import Workflow, WorkflowRun, WorkflowSpec

gradient/api_sdk/models/workflows.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import datetime
2+
3+
import attr
4+
5+
6+
@attr.s
7+
class Workflow(object):
8+
id = attr.ib(type=str, default=None)
9+
team_id = attr.ib(type=int, default=None)
10+
project_id = attr.ib(type=int, default=None)
11+
name = attr.ib(type=str, default=None)
12+
workflow_spec_id = attr.ib(type=str, default=None)
13+
dt_deleted = attr.ib(type=datetime.datetime, default=None)
14+
dt_created = attr.ib(type=datetime.datetime, default=None)
15+
dt_modified = attr.ib(type=datetime.datetime, default=None)
16+
17+
@attr.s
18+
class WorkflowSpec(object):
19+
id = attr.ib(type=str, default=None)
20+
data = attr.ib(type=str, default=None)
21+
hash_sha256 = attr.ib(type=str, default=None)
22+
dt_created = attr.ib(type=datetime.datetime, default=None)
23+
24+
@attr.s
25+
class WorkflowRun(object):
26+
id = attr.ib(type=str, default=None)
27+
team_id = attr.ib(type=int, default=None)
28+
workflow_id = attr.ib(type=str, default=None)
29+
cluster_id = attr.ib(type=int, default=None)
30+
user_id = attr.ib(type=int, default=None)
31+
workflow_spec_id = attr.ib(type=str, default=None)
32+
seq_num = attr.ib(type=int, default=None)
33+
timeout = attr.ib(type=int, default=None)
34+
workflow_phase_id = attr.ib(type=int, default=None)
35+
name = attr.ib(type=str, default=None)
36+
message = attr.ib(type=str, default=None)
37+
dt_status = attr.ib(type=datetime.datetime, default=None)
38+
dt_started = attr.ib(type=datetime.datetime, default=None)
39+
dt_finished = attr.ib(type=datetime.datetime, default=None)
40+
dt_created = attr.ib(type=datetime.datetime, default=None)
41+
dt_modified = attr.ib(type=datetime.datetime, default=None)

gradient/api_sdk/repositories/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,4 @@
2323
from .storage_providers import ListStorageProviders, CreateStorageProvider, DeleteStorageProvider, \
2424
GetStorageProvider, UpdateStorageProvider
2525
from .tensorboards import CreateTensorboard, GetTensorboard, ListTensorboards, UpdateTensorboard, DeleteTensorboard
26+
from .workflows import ListWorkflows, GetWorkflow, ListWorkflowRuns, GetWorkflowRun, CreateWorkflow, CreateWorkflowRun, ListWorkflowLogs
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
from .common import BaseRepository, ListResources, GetResource, ListLogs
2+
from .. import config, serializers
3+
from ..clients import http_client
4+
import json
5+
6+
class WorkflowsMixin(object):
7+
SERIALIZER_CLS = serializers.WorkflowSchema
8+
9+
@staticmethod
10+
def _get_api_url(**kwargs):
11+
return config.config.CONFIG_HOST
12+
13+
14+
class ListWorkflows(WorkflowsMixin, ListResources):
15+
def get_request_url(self, **kwargs):
16+
project_id = kwargs.get("project_id")
17+
if project_id is not None:
18+
return "/workflows?filter[where][projectId]={}".format(project_id)
19+
return "/workflows"
20+
21+
def _get_instances(self, response, **kwargs):
22+
if not response.data:
23+
return []
24+
25+
objects = self._parse_objects(response.data, **kwargs)
26+
return objects
27+
28+
class GetWorkflow(WorkflowsMixin, BaseRepository):
29+
def get_request_url(self, **kwargs):
30+
return "/workflows/{}".format(kwargs.get("id"))
31+
32+
def get(self, **kwargs):
33+
json_ = self._get_request_json(kwargs)
34+
params = self._get_request_params(kwargs)
35+
url = self.get_request_url(**kwargs)
36+
client = self._get_client(**kwargs)
37+
response = self._send_request(client, url, json=json_, params=params)
38+
gradient_response = http_client.GradientResponse.interpret_response(response)
39+
40+
if not gradient_response.data:
41+
return {}
42+
43+
return gradient_response.data
44+
45+
46+
class WorkflowRunsMixin(object):
47+
@staticmethod
48+
def _get_api_url(**kwargs):
49+
return config.config.CONFIG_HOST
50+
51+
class ListWorkflowRuns(WorkflowRunsMixin, BaseRepository):
52+
@staticmethod
53+
def _get_api_url(**kwargs):
54+
return config.config.CONFIG_HOST
55+
56+
57+
def get_request_url(self, **kwargs):
58+
return "/workflows/{}/runs".format(kwargs.get("id"))
59+
60+
def get(self, **kwargs):
61+
json_ = self._get_request_json(kwargs)
62+
params = self._get_request_params(kwargs)
63+
url = self.get_request_url(**kwargs)
64+
client = self._get_client(**kwargs)
65+
response = self._send_request(client, url, json=json_, params=params)
66+
gradient_response = http_client.GradientResponse.interpret_response(response)
67+
68+
if not gradient_response.data:
69+
return []
70+
71+
return gradient_response.data
72+
73+
class GetWorkflowRun(WorkflowRunsMixin, BaseRepository):
74+
def get_request_url(self, **kwargs):
75+
return "/workflows/{}/runs/{}".format(kwargs.get("id"), kwargs.get("run"))
76+
77+
def get(self, **kwargs):
78+
json_ = self._get_request_json(kwargs)
79+
params = self._get_request_params(kwargs)
80+
url = self.get_request_url(**kwargs)
81+
client = self._get_client(**kwargs)
82+
response = self._send_request(client, url, json=json_, params=params)
83+
gradient_response = http_client.GradientResponse.interpret_response(response)
84+
85+
if not gradient_response.data:
86+
return {}
87+
88+
return gradient_response.data
89+
90+
91+
class CreateWorkflow(WorkflowsMixin, BaseRepository):
92+
def get_request_url(self, **kwargs):
93+
return "/workflows"
94+
95+
def _get_request_json(self, kwargs):
96+
return {"name": kwargs.get("name"), "projectId": kwargs.get("project_id")}
97+
98+
def _send_request(self, client, url, json=None, params=None):
99+
response = client.post(url, json=json, params=params)
100+
return response
101+
102+
def create(self, **kwargs):
103+
response = self._get(**kwargs)
104+
self._validate_response(response)
105+
106+
if not response.data:
107+
return {}
108+
109+
return response.data
110+
111+
class CreateWorkflowRun(WorkflowsMixin, BaseRepository):
112+
def get_request_url(self, **kwargs):
113+
return "/workflows/{}/runs".format(kwargs.get("id"))
114+
115+
def _get_request_json(self, **kwargs):
116+
if kwargs.get("inputs") is not None:
117+
return {"spec": kwargs.get("spec"), "clusterId": kwargs.get("cluster_id"), "run": True, "markDefault": False, "inputs": kwargs.get("inputs") }
118+
119+
return {"spec": kwargs.get("spec"), "clusterId": kwargs.get("cluster_id"), "run": True, "markDefault": False }
120+
121+
def _send_create_request(self, **kwargs):
122+
url = self.get_request_url(**kwargs)
123+
client = self._get_client(**kwargs)
124+
json_ = self._get_request_json(**kwargs)
125+
126+
response = client.post(url, json=json_)
127+
gradient_response = http_client.GradientResponse.interpret_response(response)
128+
129+
json_formatted_str = json.dumps(gradient_response.data, indent=4)
130+
return gradient_response
131+
132+
def create(self, **kwargs):
133+
response = self._send_create_request(**kwargs)
134+
self._validate_response(response)
135+
136+
if not response.data:
137+
return {}
138+
139+
return response.data
140+
141+
class ListWorkflowLogs(ListLogs):
142+
def _get_request_params(self, kwargs):
143+
params = {
144+
"jobId": kwargs["id"],
145+
"line": kwargs["line"],
146+
"limit": kwargs["limit"],
147+
}
148+
return params

gradient/api_sdk/serializers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,4 @@
1717
from .tag import TagSchema
1818
from .tensorboard import InstanceSchema, TensorboardSchema, TensorboardDetailSchema
1919
from .vm_type import VmTypeSchema, VmTypeGpuModelSchema
20+
from .workflows import WorkflowSchema, WorkflowRunSchema, WorkflowSpecSchema

0 commit comments

Comments
 (0)