Skip to content

Commit 285bc62

Browse files
authored
fix(sdk): expose datasets client as part of the sdk_client PLA-366 (#358)
1 parent d440a83 commit 285bc62

File tree

6 files changed

+65
-25
lines changed

6 files changed

+65
-25
lines changed

.gitignore

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,3 +106,7 @@ ENV/
106106
.idea
107107
paperspace-python.zip
108108
*.env
109+
110+
bin/
111+
lib64
112+
share/

gradient/api_sdk/clients/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from .project_client import ProjectsClient
1414
from .secret_client import SecretsClient
1515
from .storage_provider_client import StorageProvidersClient
16-
from .sdk_client import SdkClient
1716
from .tensorboards_client import TensorboardClient
1817
from .workflow_client import WorkflowsClient
18+
19+
from .sdk_client import SdkClient
Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
from . import DeploymentsClient, ExperimentsClient, HyperparameterJobsClient, ModelsClient, ProjectsClient, \
2-
MachinesClient, NotebooksClient, SecretsClient
2+
MachinesClient, NotebooksClient, SecretsClient, DatasetsClient, MachineTypesClient, DatasetVersionsClient, \
3+
DatasetTagsClient, ClustersClient, StorageProvidersClient
34
from .job_client import JobsClient
45
from .workflow_client import WorkflowsClient
6+
from .tensorboards_client import TensorboardClient
57
from .. import logger as sdk_logger
68

79

@@ -11,13 +13,23 @@ def __init__(self, api_key, logger=sdk_logger.MuteLogger()):
1113
:param str api_key: API key
1214
:param sdk_logger.Logger logger:
1315
"""
14-
self.experiments = ExperimentsClient(api_key=api_key, logger=logger)
16+
self.clusters = ClustersClient(api_key=api_key, logger=logger)
17+
self.datasets = DatasetsClient(api_key=api_key, logger=logger)
18+
self.dataset_tags = DatasetTagsClient(api_key=api_key, logger=logger)
19+
self.dataset_versions = DatasetVersionsClient(
20+
api_key=api_key, logger=logger)
1521
self.deployments = DeploymentsClient(api_key=api_key, logger=logger)
16-
self.hyperparameters = HyperparameterJobsClient(api_key=api_key, logger=logger)
17-
self.models = ModelsClient(api_key=api_key, logger=logger)
22+
self.experiments = ExperimentsClient(api_key=api_key, logger=logger)
23+
self.hyperparameters = HyperparameterJobsClient(
24+
api_key=api_key, logger=logger)
1825
self.jobs = JobsClient(api_key=api_key, logger=logger)
19-
self.projects = ProjectsClient(api_key=api_key, logger=logger)
26+
self.machine_types = MachineTypesClient(api_key=api_key, logger=logger)
2027
self.machines = MachinesClient(api_key=api_key, logger=logger)
28+
self.models = ModelsClient(api_key=api_key, logger=logger)
2129
self.notebooks = NotebooksClient(api_key=api_key, logger=logger)
30+
self.projects = ProjectsClient(api_key=api_key, logger=logger)
2231
self.secrets = SecretsClient(api_key=api_key, logger=logger)
32+
self.storage_providers = StorageProvidersClient(
33+
api_key=api_key, logger=logger)
34+
self.tensorboards = TensorboardClient(api_key=api_key, logger=logger)
2335
self.workflows = WorkflowsClient(api_key=api_key, logger=logger)

gradient/api_sdk/clients/tensorboards_client.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def add_experiments(self, id, added_experiments):
170170
:raises: ResourceFetchingError: When there is problem with response from API
171171
"""
172172
repository = self.build_repository(repositories.UpdateTensorboard)
173-
tensorboard = repository.update(id=id, added_experiments=added_experiments)
173+
tensorboard = repository.update(
174+
id=id, added_experiments=added_experiments)
174175
return tensorboard
175176

176177
def remove_experiments(self, id, removed_experiments):
@@ -197,7 +198,8 @@ def remove_experiments(self, id, removed_experiments):
197198
:raises: ResourceFetchingError: When there is problem with response from API
198199
"""
199200
repository = self.build_repository(repositories.UpdateTensorboard)
200-
tensorboard = repository.update(id=id, removed_experiments=removed_experiments)
201+
tensorboard = repository.update(
202+
id=id, removed_experiments=removed_experiments)
201203
return tensorboard
202204

203205
def delete(self, id):

gradient/commands/experiments.py

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,8 @@ def maybe_add_to_tensorboard(self, tensorboard_id, experiment_id):
5454

5555
tensorboards = self._get_tensorboards()
5656
if len(tensorboards) == 1:
57-
self._add_experiment_to_tensorboard(tensorboards[0].id, experiment_id)
57+
self._add_experiment_to_tensorboard(
58+
tensorboards[0].id, experiment_id)
5859
else:
5960
self._create_tensorboard_with_experiment(experiment_id)
6061

@@ -64,15 +65,17 @@ def _add_experiment_to_tensorboard(self, tensorboard_id, experiment_id):
6465
:param str tensorboard_id:
6566
:param str experiment_id:
6667
"""
67-
command = tensorboards_commands.AddExperimentToTensorboard(api_key=self.api_key)
68+
command = tensorboards_commands.AddExperimentToTensorboard(
69+
api_key=self.api_key)
6870
command.execute(tensorboard_id, [experiment_id])
6971

7072
def _get_tensorboards(self):
7173
"""Get tensorboards
7274
7375
:rtype: list[api_sdk.Tensorboard]
7476
"""
75-
tensorboard_client = TensorboardClient(api_key=self.api_key, logger=self.logger)
77+
tensorboard_client = TensorboardClient(
78+
api_key=self.api_key, logger=self.logger)
7679
tensorboards = tensorboard_client.list()
7780
return tensorboards
7881

@@ -81,7 +84,8 @@ def _create_tensorboard_with_experiment(self, experiment_id):
8184
8285
:param str experiment_id:
8386
"""
84-
command = tensorboards_commands.CreateTensorboardCommand(api_key=self.api_key)
87+
command = tensorboards_commands.CreateTensorboardCommand(
88+
api_key=self.api_key)
8589
command.execute(experiments=[experiment_id])
8690

8791

@@ -101,14 +105,18 @@ def execute(self, json_, add_to_tensorboard=False):
101105
with halo.Halo(text=self.SPINNER_MESSAGE, spinner="dots"):
102106
experiment_id = self._create(json_)
103107

104-
self.logger.log(self.CREATE_SUCCESS_MESSAGE_TEMPLATE.format(experiment_id))
105-
self.logger.log(self.get_instance_url(experiment_id, json_["project_id"]))
108+
self.logger.log(
109+
self.CREATE_SUCCESS_MESSAGE_TEMPLATE.format(experiment_id))
110+
self.logger.log(self.get_instance_url(
111+
experiment_id, json_["project_id"]))
106112

107-
self._maybe_add_to_tensorboard(add_to_tensorboard, experiment_id, self.api_key)
113+
self._maybe_add_to_tensorboard(
114+
add_to_tensorboard, experiment_id, self.api_key)
108115
return experiment_id
109116

110117
def get_instance_url(self, instance_id, project_id):
111-
url = concatenate_urls(config.WEB_URL, "{}/projects/{}/experiments/{}".format(self.get_namespace(), project_id, instance_id))
118+
url = concatenate_urls(config.WEB_URL, "{}/projects/{}/experiments/{}".format(
119+
self.get_namespace(), project_id, instance_id))
112120
return url
113121

114122
def _handle_workspace(self, instance_dict):
@@ -129,7 +137,8 @@ def _maybe_add_to_tensorboard(self, tensorboard_id, experiment_id, api_key):
129137
"""
130138
if tensorboard_id is not False:
131139
tensorboard_handler = TensorboardHandler(api_key)
132-
tensorboard_handler.maybe_add_to_tensorboard(tensorboard_id, experiment_id)
140+
tensorboard_handler.maybe_add_to_tensorboard(
141+
tensorboard_id, experiment_id)
133142

134143
@staticmethod
135144
def _handle_dataset_data(json_):
@@ -151,12 +160,13 @@ def _handle_dataset_data(json_):
151160
return
152161
else:
153162
datasets_len = max(len(datasets[0]), len(datasets[1]))
154-
other_dataset_param_max_len = max(len(elem) for elem in datasets[2:])
163+
other_dataset_param_max_len = max(
164+
len(elem) for elem in datasets[2:])
155165
if datasets_len < other_dataset_param_max_len:
156166
# there no point in defining n+1 dataset parameters of one type for n datasets
157167
raise click.BadParameter(
158168
"Too many dataset parameter sets ({}) for {} dataset URIs. Forgot to add one more dataset URI?"
159-
.format(other_dataset_param_max_len, datasets_len))
169+
.format(other_dataset_param_max_len, datasets_len))
160170

161171
datasets = [none_strings_to_none_objects(d) for d in datasets]
162172

@@ -194,7 +204,8 @@ def _create(self, json_):
194204

195205
class CreateMpiMultiNodeExperimentCommand(BaseCreateExperimentCommandMixin, BaseExperimentCommand):
196206
def _create(self, json_):
197-
json_.pop("experiment_type_id", None) # for MPI there is no experiment_type_id parameter in client method
207+
# for MPI there is no experiment_type_id parameter in client method
208+
json_.pop("experiment_type_id", None)
198209
handle = self.client.create_mpi_multi_node(**json_)
199210
return handle
200211

@@ -213,7 +224,8 @@ class CreateAndStartMpiMultiNodeExperimentCommand(BaseCreateExperimentCommandMix
213224
CREATE_SUCCESS_MESSAGE_TEMPLATE = "New experiment created and started with ID: {}"
214225

215226
def _create(self, json_):
216-
json_.pop("experiment_type_id", None) # for MPI there is no experiment_type_id parameter in client method
227+
# for MPI there is no experiment_type_id parameter in client method
228+
json_.pop("experiment_type_id", None)
217229
handle = self.client.run_mpi_multi_node(**json_)
218230
return handle
219231

@@ -283,7 +295,8 @@ def _get_table_data(self, experiment):
283295
if experiment.experiment_type_id == constants.ExperimentType.MPI_MULTI_NODE:
284296
return self._get_multi_node_mpi_data(experiment)
285297

286-
raise ValueError("Wrong experiment type: {}".format(experiment.experiment_type_id))
298+
raise ValueError("Wrong experiment type: {}".format(
299+
experiment.experiment_type_id))
287300

288301
@staticmethod
289302
def _get_single_node_data(experiment):
@@ -323,13 +336,15 @@ def _get_multi_node_grpc_data(experiment):
323336
("Artifact directory", experiment.artifact_directory),
324337
("Cluster ID", experiment.cluster_id),
325338
("Experiment Env", experiment.experiment_env),
326-
("Experiment Type", constants.ExperimentType.get_type_str(experiment.experiment_type_id)),
339+
("Experiment Type", constants.ExperimentType.get_type_str(
340+
experiment.experiment_type_id)),
327341
("Model Type", experiment.model_type),
328342
("Model Path", experiment.model_path),
329343
("Parameter Server Command", experiment.parameter_server_command),
330344
("Parameter Server Container", experiment.parameter_server_container),
331345
("Parameter Server Count", experiment.parameter_server_count),
332-
("Parameter Server Machine Type", experiment.parameter_server_machine_type),
346+
("Parameter Server Machine Type",
347+
experiment.parameter_server_machine_type),
333348
("Ports", experiment.ports),
334349
("Project ID", experiment.project_id),
335350
("Worker Command", experiment.worker_command),
@@ -356,7 +371,8 @@ def _get_multi_node_mpi_data(experiment):
356371
("Artifact directory", experiment.artifact_directory),
357372
("Cluster ID", experiment.cluster_id),
358373
("Experiment Env", experiment.experiment_env),
359-
("Experiment Type", constants.ExperimentType.get_type_str(experiment.experiment_type_id)),
374+
("Experiment Type", constants.ExperimentType.get_type_str(
375+
experiment.experiment_type_id)),
360376
("Model Type", experiment.model_type),
361377
("Model Path", experiment.model_path),
362378
("Master Command", experiment.master_command),
@@ -428,6 +444,7 @@ def execute(self, experiment_id, start, end, interval, built_in_metrics, *args,
428444
formatted_metrics = json.dumps(metrics, indent=2, sort_keys=True)
429445
self.logger.log(formatted_metrics)
430446

447+
431448
class ListExperimentMetricsCommand(BaseExperimentCommand):
432449
def execute(self, experiment_id, start, end, interval, *args, **kwargs):
433450
metrics = self.client.list_metrics(
@@ -439,5 +456,6 @@ def execute(self, experiment_id, start, end, interval, *args, **kwargs):
439456
formatted_metrics = json.dumps(metrics, indent=2, sort_keys=True)
440457
self.logger.log(formatted_metrics)
441458

459+
442460
class StreamExperimentMetricsCommand(StreamMetricsCommand, BaseExperimentCommand):
443461
pass

pyvenv.cfg

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
home = /home/linuxbrew/.linuxbrew/bin
2+
include-system-site-packages = false
3+
version = 3.9.0

0 commit comments

Comments
 (0)