@@ -54,7 +54,8 @@ def maybe_add_to_tensorboard(self, tensorboard_id, experiment_id):
54
54
55
55
tensorboards = self ._get_tensorboards ()
56
56
if len (tensorboards ) == 1 :
57
- self ._add_experiment_to_tensorboard (tensorboards [0 ].id , experiment_id )
57
+ self ._add_experiment_to_tensorboard (
58
+ tensorboards [0 ].id , experiment_id )
58
59
else :
59
60
self ._create_tensorboard_with_experiment (experiment_id )
60
61
@@ -64,15 +65,17 @@ def _add_experiment_to_tensorboard(self, tensorboard_id, experiment_id):
64
65
:param str tensorboard_id:
65
66
:param str experiment_id:
66
67
"""
67
- command = tensorboards_commands .AddExperimentToTensorboard (api_key = self .api_key )
68
+ command = tensorboards_commands .AddExperimentToTensorboard (
69
+ api_key = self .api_key )
68
70
command .execute (tensorboard_id , [experiment_id ])
69
71
70
72
def _get_tensorboards (self ):
71
73
"""Get tensorboards
72
74
73
75
:rtype: list[api_sdk.Tensorboard]
74
76
"""
75
- tensorboard_client = TensorboardClient (api_key = self .api_key , logger = self .logger )
77
+ tensorboard_client = TensorboardClient (
78
+ api_key = self .api_key , logger = self .logger )
76
79
tensorboards = tensorboard_client .list ()
77
80
return tensorboards
78
81
@@ -81,7 +84,8 @@ def _create_tensorboard_with_experiment(self, experiment_id):
81
84
82
85
:param str experiment_id:
83
86
"""
84
- command = tensorboards_commands .CreateTensorboardCommand (api_key = self .api_key )
87
+ command = tensorboards_commands .CreateTensorboardCommand (
88
+ api_key = self .api_key )
85
89
command .execute (experiments = [experiment_id ])
86
90
87
91
@@ -101,14 +105,18 @@ def execute(self, json_, add_to_tensorboard=False):
101
105
with halo .Halo (text = self .SPINNER_MESSAGE , spinner = "dots" ):
102
106
experiment_id = self ._create (json_ )
103
107
104
- self .logger .log (self .CREATE_SUCCESS_MESSAGE_TEMPLATE .format (experiment_id ))
105
- self .logger .log (self .get_instance_url (experiment_id , json_ ["project_id" ]))
108
+ self .logger .log (
109
+ self .CREATE_SUCCESS_MESSAGE_TEMPLATE .format (experiment_id ))
110
+ self .logger .log (self .get_instance_url (
111
+ experiment_id , json_ ["project_id" ]))
106
112
107
- self ._maybe_add_to_tensorboard (add_to_tensorboard , experiment_id , self .api_key )
113
+ self ._maybe_add_to_tensorboard (
114
+ add_to_tensorboard , experiment_id , self .api_key )
108
115
return experiment_id
109
116
110
117
def get_instance_url (self , instance_id , project_id ):
111
- url = concatenate_urls (config .WEB_URL , "{}/projects/{}/experiments/{}" .format (self .get_namespace (), project_id , instance_id ))
118
+ url = concatenate_urls (config .WEB_URL , "{}/projects/{}/experiments/{}" .format (
119
+ self .get_namespace (), project_id , instance_id ))
112
120
return url
113
121
114
122
def _handle_workspace (self , instance_dict ):
@@ -129,7 +137,8 @@ def _maybe_add_to_tensorboard(self, tensorboard_id, experiment_id, api_key):
129
137
"""
130
138
if tensorboard_id is not False :
131
139
tensorboard_handler = TensorboardHandler (api_key )
132
- tensorboard_handler .maybe_add_to_tensorboard (tensorboard_id , experiment_id )
140
+ tensorboard_handler .maybe_add_to_tensorboard (
141
+ tensorboard_id , experiment_id )
133
142
134
143
@staticmethod
135
144
def _handle_dataset_data (json_ ):
@@ -151,12 +160,13 @@ def _handle_dataset_data(json_):
151
160
return
152
161
else :
153
162
datasets_len = max (len (datasets [0 ]), len (datasets [1 ]))
154
- other_dataset_param_max_len = max (len (elem ) for elem in datasets [2 :])
163
+ other_dataset_param_max_len = max (
164
+ len (elem ) for elem in datasets [2 :])
155
165
if datasets_len < other_dataset_param_max_len :
156
166
# there no point in defining n+1 dataset parameters of one type for n datasets
157
167
raise click .BadParameter (
158
168
"Too many dataset parameter sets ({}) for {} dataset URIs. Forgot to add one more dataset URI?"
159
- .format (other_dataset_param_max_len , datasets_len ))
169
+ .format (other_dataset_param_max_len , datasets_len ))
160
170
161
171
datasets = [none_strings_to_none_objects (d ) for d in datasets ]
162
172
@@ -194,7 +204,8 @@ def _create(self, json_):
194
204
195
205
class CreateMpiMultiNodeExperimentCommand (BaseCreateExperimentCommandMixin , BaseExperimentCommand ):
196
206
def _create (self , json_ ):
197
- json_ .pop ("experiment_type_id" , None ) # for MPI there is no experiment_type_id parameter in client method
207
+ # for MPI there is no experiment_type_id parameter in client method
208
+ json_ .pop ("experiment_type_id" , None )
198
209
handle = self .client .create_mpi_multi_node (** json_ )
199
210
return handle
200
211
@@ -213,7 +224,8 @@ class CreateAndStartMpiMultiNodeExperimentCommand(BaseCreateExperimentCommandMix
213
224
CREATE_SUCCESS_MESSAGE_TEMPLATE = "New experiment created and started with ID: {}"
214
225
215
226
def _create (self , json_ ):
216
- json_ .pop ("experiment_type_id" , None ) # for MPI there is no experiment_type_id parameter in client method
227
+ # for MPI there is no experiment_type_id parameter in client method
228
+ json_ .pop ("experiment_type_id" , None )
217
229
handle = self .client .run_mpi_multi_node (** json_ )
218
230
return handle
219
231
@@ -283,7 +295,8 @@ def _get_table_data(self, experiment):
283
295
if experiment .experiment_type_id == constants .ExperimentType .MPI_MULTI_NODE :
284
296
return self ._get_multi_node_mpi_data (experiment )
285
297
286
- raise ValueError ("Wrong experiment type: {}" .format (experiment .experiment_type_id ))
298
+ raise ValueError ("Wrong experiment type: {}" .format (
299
+ experiment .experiment_type_id ))
287
300
288
301
@staticmethod
289
302
def _get_single_node_data (experiment ):
@@ -323,13 +336,15 @@ def _get_multi_node_grpc_data(experiment):
323
336
("Artifact directory" , experiment .artifact_directory ),
324
337
("Cluster ID" , experiment .cluster_id ),
325
338
("Experiment Env" , experiment .experiment_env ),
326
- ("Experiment Type" , constants .ExperimentType .get_type_str (experiment .experiment_type_id )),
339
+ ("Experiment Type" , constants .ExperimentType .get_type_str (
340
+ experiment .experiment_type_id )),
327
341
("Model Type" , experiment .model_type ),
328
342
("Model Path" , experiment .model_path ),
329
343
("Parameter Server Command" , experiment .parameter_server_command ),
330
344
("Parameter Server Container" , experiment .parameter_server_container ),
331
345
("Parameter Server Count" , experiment .parameter_server_count ),
332
- ("Parameter Server Machine Type" , experiment .parameter_server_machine_type ),
346
+ ("Parameter Server Machine Type" ,
347
+ experiment .parameter_server_machine_type ),
333
348
("Ports" , experiment .ports ),
334
349
("Project ID" , experiment .project_id ),
335
350
("Worker Command" , experiment .worker_command ),
@@ -356,7 +371,8 @@ def _get_multi_node_mpi_data(experiment):
356
371
("Artifact directory" , experiment .artifact_directory ),
357
372
("Cluster ID" , experiment .cluster_id ),
358
373
("Experiment Env" , experiment .experiment_env ),
359
- ("Experiment Type" , constants .ExperimentType .get_type_str (experiment .experiment_type_id )),
374
+ ("Experiment Type" , constants .ExperimentType .get_type_str (
375
+ experiment .experiment_type_id )),
360
376
("Model Type" , experiment .model_type ),
361
377
("Model Path" , experiment .model_path ),
362
378
("Master Command" , experiment .master_command ),
@@ -428,6 +444,7 @@ def execute(self, experiment_id, start, end, interval, built_in_metrics, *args,
428
444
formatted_metrics = json .dumps (metrics , indent = 2 , sort_keys = True )
429
445
self .logger .log (formatted_metrics )
430
446
447
+
431
448
class ListExperimentMetricsCommand (BaseExperimentCommand ):
432
449
def execute (self , experiment_id , start , end , interval , * args , ** kwargs ):
433
450
metrics = self .client .list_metrics (
@@ -439,5 +456,6 @@ def execute(self, experiment_id, start, end, interval, *args, **kwargs):
439
456
formatted_metrics = json .dumps (metrics , indent = 2 , sort_keys = True )
440
457
self .logger .log (formatted_metrics )
441
458
459
+
442
460
class StreamExperimentMetricsCommand (StreamMetricsCommand , BaseExperimentCommand ):
443
461
pass
0 commit comments