Skip to content

Commit 19e26ea

Browse files
authored
Added a function to retrieve a model from a comparison (#196)
* Added a function to retrieve a model from a comparison * PR fixes * Added pattern matching * Added exception
1 parent 0fd738a commit 19e26ea

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

dataikuapi/dss/project.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import time, warnings, sys, os.path as osp
2+
import re
23
from .dataset import DSSDataset, DSSDatasetListItem, DSSManagedDatasetCreationHelper
34
from .modelcomparison import DSSModelComparison
45
from .jupyternotebook import DSSJupyterNotebook, DSSJupyterNotebookListItem
@@ -865,6 +866,36 @@ def create_model_comparison(self, name, prediction_type):
865866
mec_id = res['id']
866867
return DSSModelComparison(self.client, self.project_key, mec_id)
867868

869+
def get_from_full_id(self, full_id):
870+
"""
871+
Retrieves a Saved Model from the flow, a Lab Model from an Analysis or a Model Evaluation from a Model Evaluation Store) using its full id.
872+
873+
:param string full_id: the full id of the item to retrieve
874+
875+
:returns: A handle on the Saved Model, the Model Evaluation or the Lab Model
876+
:rtype: :class:`dataikuapi.dss.savedmodel.DSSSavedModel`
877+
:rtype: :class:`dataikuapi.dss.modelevaluationstore.DSSModelEvaluation`
878+
:rtype: :class:`dataikuapi.dss.ml.DSSTrainedPredictionModelDetails`
879+
"""
880+
881+
saved_model_pattern = re.compile("^S-(\\w+)-(\\w+)-(\\w+)(?:-part-(\\w+)-(v?\\d+))?$\\Z")
882+
analysis_model_pattern = re.compile("^A-(\\w+)-(\\w+)-(\\w+)-(s[0-9]+)-(pp[0-9]+(?:-part-(\\w+)|-base)?)-(m[0-9]+)$\\Z")
883+
model_evaluation_pattern = re.compile("^ME-(\\w+)-(\\w+)-(\\w+)$\\Z")
884+
885+
if saved_model_pattern.match(full_id):
886+
return self.get_saved_model(full_id)
887+
elif model_evaluation_pattern.match(full_id):
888+
mes_id = full_id.split('-')[2]
889+
evaluation_id = full_id.split('-')[3]
890+
mes = self.get_model_evaluation_store(mes_id)
891+
return mes.get_model_evaluation(evaluation_id)
892+
elif analysis_model_pattern.match(full_id):
893+
analysis_id = full_id.split('-')[2]
894+
task_id = full_id.split('-')[3]
895+
return self.get_ml_task(analysis_id, task_id).get_trained_model_details(full_id)
896+
897+
raise ValueError("{} is not a valid full model id or full model evaluation id.".format(full_id))
898+
868899
########################################################
869900
# Jobs
870901
########################################################

0 commit comments

Comments
 (0)