11package ml .comet .experiment .impl ;
22
3+ import io .reactivex .rxjava3 .core .Observable ;
34import lombok .NonNull ;
45import lombok .SneakyThrows ;
56import ml .comet .experiment .CometApi ;
89import ml .comet .experiment .impl .config .CometConfig ;
910import ml .comet .experiment .impl .http .Connection ;
1011import ml .comet .experiment .impl .http .ConnectionInitializer ;
12+ import ml .comet .experiment .impl .rest .ExperimentModelListResponse ;
13+ import ml .comet .experiment .impl .rest .ExperimentModelResponse ;
14+ import ml .comet .experiment .impl .rest .RegistryModelCreateRequest ;
15+ import ml .comet .experiment .impl .rest .RegistryModelItemCreateRequest ;
16+ import ml .comet .experiment .impl .rest .RegistryModelOverviewListResponse ;
1117import ml .comet .experiment .impl .utils .CometUtils ;
18+ import ml .comet .experiment .impl .utils .DataModelUtils ;
1219import ml .comet .experiment .model .ExperimentMetadata ;
1320import ml .comet .experiment .model .Project ;
21+ import ml .comet .experiment .registrymodel .Model ;
22+ import ml .comet .experiment .registrymodel .ModelNotFoundException ;
23+ import ml .comet .experiment .registrymodel .ModelRegistryRecord ;
1424import org .apache .commons .lang3 .StringUtils ;
1525import org .slf4j .Logger ;
1626import org .slf4j .LoggerFactory ;
2030import java .util .ArrayList ;
2131import java .util .List ;
2232import java .util .Objects ;
33+ import java .util .Optional ;
34+ import java .util .stream .Collectors ;
2335
2436import static ml .comet .experiment .impl .config .CometConfig .COMET_API_KEY ;
2537import static ml .comet .experiment .impl .config .CometConfig .COMET_BASE_URL ;
2638import static ml .comet .experiment .impl .config .CometConfig .COMET_MAX_AUTH_RETRIES ;
39+ import static ml .comet .experiment .impl .resources .LogMessages .EXPERIMENT_HAS_NO_MODELS ;
40+ import static ml .comet .experiment .impl .resources .LogMessages .FAILED_TO_FIND_EXPERIMENT_MODEL_BY_NAME ;
41+ import static ml .comet .experiment .impl .resources .LogMessages .MODEL_REGISTERED_IN_WORKSPACE ;
42+ import static ml .comet .experiment .impl .resources .LogMessages .MODEL_VERSION_CREATED_IN_WORKSPACE ;
43+ import static ml .comet .experiment .impl .resources .LogMessages .UPDATE_REGISTRY_MODEL_DESCRIPTION_IGNORED ;
44+ import static ml .comet .experiment .impl .resources .LogMessages .UPDATE_REGISTRY_MODEL_IS_PUBLIC_IGNORED ;
45+ import static ml .comet .experiment .impl .resources .LogMessages .getString ;
2746
2847/**
2948 * The implementation of the {@link CometApi}.
@@ -91,6 +110,69 @@ public List<ExperimentMetadata> getAllExperiments(@NonNull String projectId) {
91110 ArrayList ::addAll );
92111 }
93112
113+ @ Override
114+ public ModelRegistryRecord registerModel (@ NonNull final Model model , @ NonNull final String experimentKey ) {
115+ // get list of experiment models
116+ List <ExperimentModelResponse > experimentModels = this .restApiClient
117+ .getExperimentModels (experimentKey )
118+ .map (ExperimentModelListResponse ::getModels )
119+ .blockingGet ();
120+
121+ // check that experiment has models registered
122+ if (experimentModels == null || experimentModels .size () == 0 ) {
123+ throw new ModelNotFoundException (getString (EXPERIMENT_HAS_NO_MODELS , experimentKey ));
124+ }
125+
126+ // check that experiment has our model in the list
127+ Optional <ExperimentModelResponse > details = experimentModels .stream ()
128+ .filter (modelResponse -> Objects .equals (modelResponse .getModelName (), model .getName ()))
129+ .findFirst ();
130+ if (!details .isPresent ()) {
131+ String names = experimentModels .stream ()
132+ .map (ExperimentModelResponse ::getModelName )
133+ .collect (Collectors .joining (", " ));
134+ throw new ModelNotFoundException (
135+ getString (FAILED_TO_FIND_EXPERIMENT_MODEL_BY_NAME , model .getName (), names ));
136+ }
137+
138+ // set model fields
139+ final RegistryModelImpl modelImpl = (RegistryModelImpl ) model ;
140+ modelImpl .setExperimentModelId (details .get ().getExperimentModelId ());
141+
142+ // check if model already registered in the experiment's workspace records
143+ Boolean modelInRegistry = this .restApiClient .getMetadata (experimentKey )
144+ .concatMap (experimentMetadataRest -> {
145+ modelImpl .setWorkspace (experimentMetadataRest .getWorkspaceName ());
146+ return this .restApiClient .getRegistryModelsForWorkspace (experimentMetadataRest .getWorkspaceName ());
147+ })
148+ .map (RegistryModelOverviewListResponse ::getRegistryModels )
149+ .flatMapObservable (Observable ::fromIterable )
150+ .any (registryModel -> Objects .equals (registryModel .getModelName (), model .getRegistryName ()))
151+ .blockingGet ();
152+
153+ ModelRegistryRecord registry ;
154+ if (modelInRegistry ) {
155+ // create new version of the model
156+ if (StringUtils .isNotBlank (modelImpl .getDescription ())) {
157+ this .logger .warn (getString (UPDATE_REGISTRY_MODEL_DESCRIPTION_IGNORED ));
158+ }
159+ if (modelImpl .getIsPublic () != null ) {
160+ this .logger .warn (getString (UPDATE_REGISTRY_MODEL_IS_PUBLIC_IGNORED ));
161+ }
162+ RegistryModelItemCreateRequest request = DataModelUtils .createRegistryModelItemCreateRequest (modelImpl );
163+ registry = this .restApiClient .createRegistryModelItem (request ).blockingGet ().toModelRegistry ();
164+ this .logger .info (getString (MODEL_VERSION_CREATED_IN_WORKSPACE ,
165+ model .getVersion (), model .getRegistryName (), model .getWorkspace ()));
166+ } else {
167+ // create model's registry record
168+ RegistryModelCreateRequest request = DataModelUtils .createRegistryModelCreateRequest (modelImpl );
169+ registry = this .restApiClient .createRegistryModel (request ).blockingGet ().toModelRegistry ();
170+ this .logger .info (getString (MODEL_REGISTERED_IN_WORKSPACE ,
171+ model .getRegistryName (), model .getVersion (), model .getWorkspace ()));
172+ }
173+ return registry ;
174+ }
175+
94176 /**
95177 * Release all resources hold by this instance, such as connection to the Comet server.
96178 *
@@ -117,6 +199,10 @@ void init() {
117199 this .restApiClient = new RestApiClient (this .connection );
118200 }
119201
202+ RestApiClient getRestApiClient () {
203+ return this .restApiClient ;
204+ }
205+
120206 /**
121207 * Returns builder to be used to properly create instance of this class.
122208 *
0 commit comments