Skip to content

Commit b31869a

Browse files
authored
Merge pull request #49 from comet-ml/CM-1830-register-model
[Cm-1830] Implement API.registerModel
2 parents 46f5b24 + fc4ced2 commit b31869a

31 files changed

+1059
-12
lines changed

comet-examples/src/main/java/ml/comet/examples/LogModelExample.java

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
package ml.comet.examples;
22

3+
import ml.comet.experiment.CometApi;
34
import ml.comet.experiment.ExperimentBuilder;
45
import ml.comet.experiment.OnlineExperiment;
56
import ml.comet.experiment.asset.LoggedExperimentAsset;
67
import ml.comet.experiment.context.ExperimentContext;
8+
import ml.comet.experiment.impl.utils.ModelUtils;
9+
import ml.comet.experiment.registrymodel.Model;
10+
import ml.comet.experiment.registrymodel.ModelRegistryRecord;
711
import org.awaitility.Awaitility;
812

913
import java.io.File;
1014
import java.nio.charset.StandardCharsets;
1115
import java.nio.file.Path;
1216
import java.util.ArrayList;
17+
import java.util.Collections;
1318
import java.util.List;
1419
import java.util.Map;
1520
import java.util.concurrent.TimeUnit;
@@ -18,7 +23,7 @@
1823
import static ml.comet.experiment.impl.asset.AssetType.MODEL_ELEMENT;
1924

2025
/**
21-
* Provides examples of working with models logging and retrieval.
26+
* Provides examples of working with models logging, registering, and retrieval.
2227
*
2328
* <p>To run from command line execute the following at the root of this module:
2429
* <pre>
@@ -34,6 +39,9 @@ public class LogModelExample implements BaseExample {
3439
private static final String SOME_MODEL_NAME = "someModelNameExample";
3540
private static final String SOME_MODEL_LOGICAL_PATH = "someExampleModelData.dat";
3641
private static final String SOME_MODEL_DATA = "some model data string";
42+
private static final String SOME_MODEL_DESCRIPTION = "LogModelExample model";
43+
private static final String SOME_MODEL_VERSION = "1.0.0";
44+
private static final String SOME_MODEL_VERSION_UP = "1.0.1";
3745

3846
/**
3947
* The main entry point to the example.
@@ -91,6 +99,38 @@ private static void run(OnlineExperiment experiment) throws Exception {
9199
System.out.printf("Retrieved %d logged assets of the model '%s':\n", assets.size(), SOME_MODEL_NAME);
92100
assets.forEach(loggedExperimentAsset -> System.out.printf("\t%s\n", loggedExperimentAsset));
93101

102+
// Register experiment model in the Comet registry
103+
//
104+
try (CometApi api = ExperimentBuilder.CometApi().build()) {
105+
String registryName = String.format("%s-%d", SOME_MODEL_NAME, System.currentTimeMillis());
106+
registryName = ModelUtils.createRegistryModelName(registryName);
107+
System.out.printf("\nRegistering model '%s' in the Comet model registry under workspace '%s'.\n",
108+
registryName, experiment.getWorkspaceName());
109+
110+
Model model = Model.newModel(SOME_MODEL_NAME)
111+
.withRegistryName(registryName)
112+
.withDescription(SOME_MODEL_DESCRIPTION)
113+
.withStages(Collections.singletonList("example"))
114+
.withVersion(SOME_MODEL_VERSION).build();
115+
ModelRegistryRecord record = api.registerModel(model, experiment.getExperimentKey());
116+
117+
System.out.printf("The experiment's model was successfully registered under record: %s\n\n", record);
118+
119+
120+
// create new version of the registered model
121+
//
122+
System.out.printf("Updating model '%s' in the Comet model registry with new version '%s'.\n",
123+
registryName, SOME_MODEL_VERSION_UP);
124+
Model updatedModel = Model.newModel(SOME_MODEL_NAME)
125+
.withRegistryName(registryName)
126+
.withDescription(SOME_MODEL_DESCRIPTION)
127+
.withStages(Collections.singletonList("production"))
128+
.withVersion(SOME_MODEL_VERSION_UP).build();
129+
130+
record = api.registerModel(updatedModel, experiment.getExperimentKey());
131+
System.out.printf("The experiment's model was successfully updated with record: %s\n\n", record);
132+
}
133+
94134
System.out.println("===== Experiment completed ====");
95135
}
96136
}

comet-java-client/src/main/java/ml/comet/experiment/CometApi.java

Lines changed: 17 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import ml.comet.experiment.model.ExperimentMetadata;
44
import ml.comet.experiment.model.Project;
5+
import ml.comet.experiment.registrymodel.Model;
6+
import ml.comet.experiment.registrymodel.ModelRegistryRecord;
57

68
import java.io.Closeable;
79
import java.util.List;
@@ -16,23 +18,32 @@ public interface CometApi extends Closeable {
1618
/**
1719
* Gets all workspaces available for current API key.
1820
*
19-
* @return List of workspace names
21+
* @return the list of workspace names
2022
*/
2123
List<String> getAllWorkspaces();
2224

2325
/**
24-
* Gets all project DTOs under specified workspace name.
26+
* Gets all projects under specified workspace name.
2527
*
2628
* @param workspaceName workspace name
27-
* @return List of project DTOs
29+
* @return the list of projects
2830
*/
2931
List<Project> getAllProjects(String workspaceName);
3032

3133
/**
32-
* Gets all experiment DTOs under specified project id.
34+
* Gets metadata of all experiments created under specified project id.
3335
*
34-
* @param projectId Project id
35-
* @return List of experiment DTOs
36+
* @param projectId the ID of the project.
37+
* @return the list of experiments' metadata objects.
3638
*/
3739
List<ExperimentMetadata> getAllExperiments(String projectId);
40+
41+
/**
42+
* Register model defined in the specified experiment in the Comet's model registry.
43+
*
44+
* @param model the {@link Model} to be registered.
45+
* @param experimentKey the identifier of the experiment where model assets was logged.
46+
* @return the {@link ModelRegistryRecord} instance holding information about model registry record.
47+
*/
48+
ModelRegistryRecord registerModel(Model model, String experimentKey);
3849
}

comet-java-client/src/main/java/ml/comet/experiment/ExperimentBuilder.java

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,10 @@
22

33
import lombok.experimental.UtilityClass;
44
import ml.comet.experiment.builder.ApiExperimentBuilder;
5+
import ml.comet.experiment.builder.CometApiBuilder;
56
import ml.comet.experiment.builder.OnlineExperimentBuilder;
67
import ml.comet.experiment.impl.ApiExperimentImpl;
8+
import ml.comet.experiment.impl.CometApiImpl;
79
import ml.comet.experiment.impl.OnlineExperimentImpl;
810

911
/**
@@ -40,4 +42,15 @@ public static OnlineExperimentBuilder OnlineExperiment() {
4042
public static ApiExperimentBuilder ApiExperiment() {
4143
return ApiExperimentImpl.builder();
4244
}
45+
46+
/**
47+
* The factory to create instance of the {@link CometApiBuilder} which can be used to configure
48+
* and create fully initialized instance of the {@link CometApi}.
49+
*
50+
* @return the instance of the {@link CometApiBuilder}.
51+
*/
52+
@SuppressWarnings("checkstyle:MethodName")
53+
public static CometApiBuilder CometApi() {
54+
return CometApiImpl.builder();
55+
}
4356
}

comet-java-client/src/main/java/ml/comet/experiment/impl/CometApiImpl.java

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package ml.comet.experiment.impl;
22

3+
import io.reactivex.rxjava3.core.Observable;
34
import lombok.NonNull;
45
import lombok.SneakyThrows;
56
import ml.comet.experiment.CometApi;
@@ -8,9 +9,18 @@
89
import ml.comet.experiment.impl.config.CometConfig;
910
import ml.comet.experiment.impl.http.Connection;
1011
import ml.comet.experiment.impl.http.ConnectionInitializer;
12+
import ml.comet.experiment.impl.rest.ExperimentModelListResponse;
13+
import ml.comet.experiment.impl.rest.ExperimentModelResponse;
14+
import ml.comet.experiment.impl.rest.RegistryModelCreateRequest;
15+
import ml.comet.experiment.impl.rest.RegistryModelItemCreateRequest;
16+
import ml.comet.experiment.impl.rest.RegistryModelOverviewListResponse;
1117
import ml.comet.experiment.impl.utils.CometUtils;
18+
import ml.comet.experiment.impl.utils.DataModelUtils;
1219
import ml.comet.experiment.model.ExperimentMetadata;
1320
import ml.comet.experiment.model.Project;
21+
import ml.comet.experiment.registrymodel.Model;
22+
import ml.comet.experiment.registrymodel.ModelNotFoundException;
23+
import ml.comet.experiment.registrymodel.ModelRegistryRecord;
1424
import org.apache.commons.lang3.StringUtils;
1525
import org.slf4j.Logger;
1626
import org.slf4j.LoggerFactory;
@@ -20,10 +30,19 @@
2030
import java.util.ArrayList;
2131
import java.util.List;
2232
import java.util.Objects;
33+
import java.util.Optional;
34+
import java.util.stream.Collectors;
2335

2436
import static ml.comet.experiment.impl.config.CometConfig.COMET_API_KEY;
2537
import static ml.comet.experiment.impl.config.CometConfig.COMET_BASE_URL;
2638
import static ml.comet.experiment.impl.config.CometConfig.COMET_MAX_AUTH_RETRIES;
39+
import static ml.comet.experiment.impl.resources.LogMessages.EXPERIMENT_HAS_NO_MODELS;
40+
import static ml.comet.experiment.impl.resources.LogMessages.FAILED_TO_FIND_EXPERIMENT_MODEL_BY_NAME;
41+
import static ml.comet.experiment.impl.resources.LogMessages.MODEL_REGISTERED_IN_WORKSPACE;
42+
import static ml.comet.experiment.impl.resources.LogMessages.MODEL_VERSION_CREATED_IN_WORKSPACE;
43+
import static ml.comet.experiment.impl.resources.LogMessages.UPDATE_REGISTRY_MODEL_DESCRIPTION_IGNORED;
44+
import static ml.comet.experiment.impl.resources.LogMessages.UPDATE_REGISTRY_MODEL_IS_PUBLIC_IGNORED;
45+
import static ml.comet.experiment.impl.resources.LogMessages.getString;
2746

2847
/**
2948
* The implementation of the {@link CometApi}.
@@ -91,6 +110,69 @@ public List<ExperimentMetadata> getAllExperiments(@NonNull String projectId) {
91110
ArrayList::addAll);
92111
}
93112

113+
@Override
114+
public ModelRegistryRecord registerModel(@NonNull final Model model, @NonNull final String experimentKey) {
115+
// get list of experiment models
116+
List<ExperimentModelResponse> experimentModels = this.restApiClient
117+
.getExperimentModels(experimentKey)
118+
.map(ExperimentModelListResponse::getModels)
119+
.blockingGet();
120+
121+
// check that experiment has models registered
122+
if (experimentModels == null || experimentModels.size() == 0) {
123+
throw new ModelNotFoundException(getString(EXPERIMENT_HAS_NO_MODELS, experimentKey));
124+
}
125+
126+
// check that experiment has our model in the list
127+
Optional<ExperimentModelResponse> details = experimentModels.stream()
128+
.filter(modelResponse -> Objects.equals(modelResponse.getModelName(), model.getName()))
129+
.findFirst();
130+
if (!details.isPresent()) {
131+
String names = experimentModels.stream()
132+
.map(ExperimentModelResponse::getModelName)
133+
.collect(Collectors.joining(", "));
134+
throw new ModelNotFoundException(
135+
getString(FAILED_TO_FIND_EXPERIMENT_MODEL_BY_NAME, model.getName(), names));
136+
}
137+
138+
// set model fields
139+
final RegistryModelImpl modelImpl = (RegistryModelImpl) model;
140+
modelImpl.setExperimentModelId(details.get().getExperimentModelId());
141+
142+
// check if model already registered in the experiment's workspace records
143+
Boolean modelInRegistry = this.restApiClient.getMetadata(experimentKey)
144+
.concatMap(experimentMetadataRest -> {
145+
modelImpl.setWorkspace(experimentMetadataRest.getWorkspaceName());
146+
return this.restApiClient.getRegistryModelsForWorkspace(experimentMetadataRest.getWorkspaceName());
147+
})
148+
.map(RegistryModelOverviewListResponse::getRegistryModels)
149+
.flatMapObservable(Observable::fromIterable)
150+
.any(registryModel -> Objects.equals(registryModel.getModelName(), model.getRegistryName()))
151+
.blockingGet();
152+
153+
ModelRegistryRecord registry;
154+
if (modelInRegistry) {
155+
// create new version of the model
156+
if (StringUtils.isNotBlank(modelImpl.getDescription())) {
157+
this.logger.warn(getString(UPDATE_REGISTRY_MODEL_DESCRIPTION_IGNORED));
158+
}
159+
if (modelImpl.getIsPublic() != null) {
160+
this.logger.warn(getString(UPDATE_REGISTRY_MODEL_IS_PUBLIC_IGNORED));
161+
}
162+
RegistryModelItemCreateRequest request = DataModelUtils.createRegistryModelItemCreateRequest(modelImpl);
163+
registry = this.restApiClient.createRegistryModelItem(request).blockingGet().toModelRegistry();
164+
this.logger.info(getString(MODEL_VERSION_CREATED_IN_WORKSPACE,
165+
model.getVersion(), model.getRegistryName(), model.getWorkspace()));
166+
} else {
167+
// create model's registry record
168+
RegistryModelCreateRequest request = DataModelUtils.createRegistryModelCreateRequest(modelImpl);
169+
registry = this.restApiClient.createRegistryModel(request).blockingGet().toModelRegistry();
170+
this.logger.info(getString(MODEL_REGISTERED_IN_WORKSPACE,
171+
model.getRegistryName(), model.getVersion(), model.getWorkspace()));
172+
}
173+
return registry;
174+
}
175+
94176
/**
95177
* Release all resources hold by this instance, such as connection to the Comet server.
96178
*
@@ -117,6 +199,10 @@ void init() {
117199
this.restApiClient = new RestApiClient(this.connection);
118200
}
119201

202+
RestApiClient getRestApiClient() {
203+
return this.restApiClient;
204+
}
205+
120206
/**
121207
* Returns builder to be used to properly create instance of this class.
122208
*

0 commit comments

Comments
 (0)