Skip to content

Commit cd9e4c5

Browse files
author
Roger Lam
authored
feat(models): add model create command (#347)
This adds the ability to create a model with a dataset ref.
1 parent de05f11 commit cd9e4c5

File tree

11 files changed

+256
-2
lines changed

11 files changed

+256
-2
lines changed

.circleci/config.yml

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ version: 2.1
22

33
orbs:
44
release-tools: paperspace/[email protected]
5+
docker-tools: paperspace/[email protected]
6+
7+
_docker_image: &docker_image paperspace/gradient-sdk
8+
_workspace_root: &workspace_root .
59

610
workflows:
711
master:
@@ -21,13 +25,39 @@ workflows:
2125
filters:
2226
branches:
2327
only: master
28+
- docker-tools/build_and_push:
29+
name: build_and_push_master
30+
context: docker-deploy
31+
docker_username: ${DOCKER_USERNAME}
32+
docker_password: ${DOCKER_PASSWORD}
33+
workspace_root: *workspace_root
34+
docker_image: *docker_image
35+
docker_tag: 0.0.0-latest
36+
requires:
37+
- test
38+
filters:
39+
branches:
40+
only: master
2441

2542
pr:
2643
jobs:
2744
- test:
2845
filters:
2946
branches:
3047
ignore: master
48+
- docker-tools/build_and_push:
49+
name: build_and_push
50+
context: docker-deploy
51+
docker_username: ${DOCKER_USERNAME}
52+
docker_password: ${DOCKER_PASSWORD}
53+
workspace_root: *workspace_root
54+
docker_image: *docker_image
55+
requires:
56+
- test
57+
filters:
58+
branches:
59+
ignore: master
60+
3161

3262
release:
3363
jobs:
@@ -37,6 +67,17 @@ workflows:
3767
only: /.*/
3868
branches:
3969
ignore: /.*/
70+
- docker-tools/tag:
71+
name: tag
72+
context: docker-deploy
73+
docker_username: ${DOCKER_USERNAME}
74+
docker_password: ${DOCKER_PASSWORD}
75+
docker_image: *docker_image
76+
filters:
77+
tags:
78+
only: /.*/
79+
branches:
80+
ignore: /.*/
4081

4182
executors:
4283
python-tox:
@@ -99,3 +140,5 @@ jobs:
99140
name: upload to pypi
100141
command: |
101142
python -m twine upload dist/*
143+
144+
build-docker:

Dockerfile

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
FROM python:3.8
2+
3+
RUN pip install --no-cache-dir --upgrade pip && \
4+
pip install --no-cache-dir gradient
5+
6+
ENV PAPERSPACE_API_KEY your_api_key_value
7+
ENV PAPERSPACE_WEB_URL https://console.paperspace.com
8+
ENV PAPERSPACE_CONFIG_HOST https://api.paperspace.io
9+
ENV PAPERSPACE_CONFIG_LOG_HOST https://logs.paperspace.io
10+
ENV PAPERSPACE_CONFIG_EXPERIMENTS_HOST https://services.paperspace.io/experiments/v1/
11+
ENV PAPERSPACE_CONFIG_EXPERIMENTS_HOST_V2 https://services.paperspace.io/experiments/v2/
12+
ENV PAPERSPACE_CONFIG_SERVICE_HOST https://services.paperspace.io
13+
14+
ENTRYPOINT ["gradient"]

gradient/api_sdk/clients/model_client.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,38 @@ def upload(self, path, name, model_type, model_summary=None, notes=None, tags=No
6161

6262
return model_id
6363

64+
def create(self, name, model_type, dataset_ref, model_summary=None, notes=None, tags=None, project_id=None):
65+
"""Create model
66+
67+
:param str name: Model name
68+
:param str model_type: Model Type
69+
:param str dataset_ref: Dataset ref to associate a model with
70+
:param dict|None model_summary: Dictionary describing model parameters like loss, accuracy, etc.
71+
:param str|None notes: Optional model description
72+
:param list[str] tags: List of tags
73+
:param str|None project_id: ID of a project
74+
75+
:return: ID of new model
76+
:rtype: str
77+
"""
78+
79+
model = models.Model(
80+
name=name,
81+
model_type=model_type,
82+
dataset_ref=dataset_ref,
83+
summary=json.dumps(model_summary) if model_summary else None,
84+
notes=notes,
85+
project_id=project_id,
86+
)
87+
88+
repository = self.build_repository(repositories.CreateModel)
89+
model_id = repository.create(model)
90+
91+
if tags:
92+
self.add_tags(entity_id=model_id, tags=tags)
93+
94+
return model_id
95+
6496
def get(self, model_id):
6597
"""Get model instance
6698

gradient/api_sdk/models/model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class Model(object):
1818
:param str deployment_state:
1919
:param str summary:
2020
:param str detail:
21+
:param str dataset_ref:
2122
"""
2223
id = attr.ib(type=str, default=None)
2324
name = attr.ib(type=str, default=None)
@@ -32,6 +33,7 @@ class Model(object):
3233
summary = attr.ib(type=dict, default=None)
3334
detail = attr.ib(type=dict, default=None)
3435
notes = attr.ib(type=str, default=None)
36+
dataset_ref = attr.ib(type=str, default=None)
3537

3638

3739
@attr.s

gradient/api_sdk/repositories/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from .machine_types import ListMachineTypes
1616
from .machines import CheckMachineAvailability, CreateMachine, CreateResource, StartMachine, StopMachine, \
1717
RestartMachine, GetMachine, UpdateMachine, GetMachineUtilization
18-
from .models import DeleteModel, ListModels, UploadModel, GetModel, ListModelFiles
18+
from .models import DeleteModel, ListModels, UploadModel, GetModel, ListModelFiles, CreateModel
1919
from .notebooks import CreateNotebook, DeleteNotebook, GetNotebook, ListNotebooks, GetNotebookMetrics, ListNotebookMetrics, \
2020
StreamNotebookMetrics, StopNotebook, StartNotebook, ForkNotebook, ListNotebookArtifacts, ListNotebookLogs
2121
from .projects import CreateProject, ListProjects, DeleteProject, GetProject

gradient/api_sdk/repositories/models.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,24 @@ def _delete_model(self, model_id):
107107
repository.delete(model_id)
108108

109109

110+
class CreateModel(GetBaseModelsApiUrlMixin, CreateResource):
111+
SERIALIZER_CLS = serializers.Model
112+
HANDLE_FIELD = "id"
113+
114+
def get_request_url(self, **kwargs):
115+
return "/mlModels/createModelV2"
116+
117+
def _get_request_params(self, kwargs):
118+
return kwargs
119+
120+
def _get_request_json(self, instance_dict):
121+
return None
122+
123+
def create(self, instance, data=None, path=None):
124+
model_id = super(CreateModel, self).create(instance, data=data, path=path)
125+
return model_id
126+
127+
110128
class GetModel(GetBaseModelsApiUrlMixin, GetResource):
111129
SERIALIZER_CLS = serializers.Model
112130

gradient/api_sdk/repositories/workflows.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,6 @@ def _send_create_request(self, **kwargs):
125125

126126
response = client.post(url, json=json_)
127127
gradient_response = http_client.GradientResponse.interpret_response(response)
128-
129128
json_formatted_str = json.dumps(gradient_response.data, indent=4)
130129
return gradient_response
131130

gradient/api_sdk/serializers/model.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ class Model(BaseSchema):
1919
summary = marshmallow.fields.Dict()
2020
detail = marshmallow.fields.Dict()
2121
notes = marshmallow.fields.Str()
22+
dataset_ref = marshmallow.fields.Str(dump_to="datasetRef", load_from="datasetRef")
2223

2324

2425
class ModelFileSchema(BaseSchema):

gradient/cli/models.py

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,69 @@ def delete_model(api_key, model_id, options_file):
6060
command.execute(model_id=model_id)
6161

6262

63+
@models_group.command("create", help="Create a model from an url or dataset id")
64+
@click.option(
65+
"--name",
66+
"name",
67+
required=True,
68+
help="Model name",
69+
cls=common.GradientOption,
70+
)
71+
@click.option(
72+
"--modelType",
73+
"model_type",
74+
required=True,
75+
type=ChoiceType(constants.MODEL_TYPES_MAP, case_sensitive=False),
76+
help="Model type",
77+
cls=common.GradientOption,
78+
)
79+
@click.option(
80+
"--datasetRef",
81+
"dataset_ref",
82+
required=True,
83+
help="Dataset ref to associate a model with",
84+
cls=common.GradientOption,
85+
)
86+
@click.option(
87+
"--projectId",
88+
"project_id",
89+
help="ID of a project",
90+
cls=common.GradientOption,
91+
)
92+
@click.option(
93+
"--modelSummary",
94+
"model_summary",
95+
type=json_string,
96+
help="Model summary",
97+
cls=common.GradientOption,
98+
)
99+
@click.option(
100+
"--notes",
101+
"notes",
102+
help="Additional notes",
103+
cls=common.GradientOption,
104+
)
105+
@click.option(
106+
"--tag",
107+
"tags",
108+
multiple=True,
109+
help="One or many tags that you want to add to experiment",
110+
cls=common.GradientOption
111+
)
112+
@click.option(
113+
"--tags",
114+
"tags_comma",
115+
help="Separated by comma tags that you want add to experiment",
116+
cls=common.GradientOption
117+
)
118+
@common.api_key_option
119+
@common.options_file
120+
def create_model(api_key, options_file, **model):
121+
model["tags"] = validate_comma_split_option(model.pop("tags_comma"), model.pop("tags"))
122+
command = models_commands.CreateModel(api_key=api_key)
123+
command.execute(**model)
124+
125+
63126
@models_group.command("upload", help="Upload a model file or directory")
64127
@click.argument(
65128
"PATH",

gradient/commands/models.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ def execute(self, model_id, *args, **kwargs):
4949
self.logger.log("Model deleted")
5050

5151

52+
class CreateModel(GetModelsClientMixin, BaseCommand):
53+
SPINNER_MESSAGE = "Creating model"
54+
55+
def execute(self, **kwargs):
56+
with halo.Halo(text=self.SPINNER_MESSAGE, spinner="dots"):
57+
model_id = self.client.create(**kwargs)
58+
59+
self.logger.log("Model created with ID: {}".format(model_id))
60+
61+
5262
class UploadModel(GetModelsClientMixin, BaseCommand):
5363
SPINNER_MESSAGE = "Uploading model"
5464

0 commit comments

Comments
 (0)