Skip to content

Commit 90f3d7d

Browse files
authored
Merge pull request #235 from dataiku/chore/dss110-sc-85322-support-passing-managed-folder-as-id
Support passing the managed folder as an id or a Folder in setup_mlflow
2 parents 53ddf7c + 529d8e5 commit 90f3d7d

File tree

2 files changed

+27
-9
lines changed

2 files changed

+27
-9
lines changed

dataikuapi/dss/project.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,8 +1612,10 @@ def setup_mlflow(self, managed_folder, host=None):
16121612
"""
16131613
Setup the dss-plugin for MLflow
16141614
1615-
:param object managed_folder: a :class:`dataikuapi.dss.DSSManagedFolder` where MLflow artifacts should be stored.
1616-
:param str host: setup a custom host if the backend used is not DSS
1615+
:param object managed_folder: the managed folder where MLflow artifacts should be stored.
1616+
Can be either a managed folder id as a string,
1617+
a :class:`dataikuapi.dss.DSSManagedFolder`, or a :class:`dataiku.Folder`
1618+
:param str host: setup a custom host if the backend used is not DSS.
16171619
"""
16181620
return MLflowHandle(client=self.client, project=self, managed_folder=managed_folder, host=host)
16191621

dataikuapi/dss_plugin_mlflow/utils.py

Lines changed: 23 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,46 @@ def __init__(self, client, project, managed_folder, host=None):
6262
"DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
6363
})
6464

65-
if not isinstance(managed_folder, DSSManagedFolder):
66-
raise TypeError('managed_folder must a DSSManagedFolder.')
67-
6865
if not client._session.verify:
6966
self.mlflow_env.update({"MLFLOW_TRACKING_INSECURE_TLS": "true"})
7067
elif isinstance(client._session.verify, str):
7168
self.mlflow_env.update({"MLFLOW_TRACKING_SERVER_CERT_PATH": client._session.verify})
7269

70+
mf_full_id = None
71+
if isinstance(managed_folder, DSSManagedFolder):
72+
mf_full_id = managed_folder.project.project_key + "." + managed_folder.id
73+
elif isinstance(managed_folder, str):
74+
mf_full_id = managed_folder
75+
else:
76+
try:
77+
from dataiku import Folder
78+
if isinstance(managed_folder, Folder):
79+
mf_full_id = managed_folder.name
80+
except ImportError:
81+
pass
82+
83+
if not mf_full_id:
84+
raise TypeError('Type of managed_folder must be "str", "DSSManagedFolder" or "dataiku.Folder".')
85+
86+
if not "." in mf_full_id:
87+
mf_full_id = self.project_key + "." + mf_full_id
88+
89+
mf_project = mf_full_id.split(".")[0]
90+
mf_id = mf_full_id.split(".")[1]
7391

74-
mf_project = managed_folder.project.project_key
75-
mf_id = managed_folder.id
7692
try:
7793
client.get_project(mf_project).get_managed_folder(mf_id).get_definition()
7894
except DataikuException as e:
7995
if "NotFoundException" in str(e):
80-
logging.error('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_id))
96+
logging.error('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_full_id))
8197
raise
8298

8399
# Set host, tracking URI, project key and managed_folder_id
84100
self.mlflow_env.update({
85101
"DSS_MLFLOW_PROJECTKEY": self.project_key,
86102
"MLFLOW_TRACKING_URI": self.client.host + "/dip/publicapi" if host is None else host,
87103
"DSS_MLFLOW_HOST": self.client.host,
88-
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_project + "." + mf_id
104+
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_full_id
89105
})
90106

91107
os.environ.update(self.mlflow_env)

0 commit comments

Comments
 (0)