Skip to content

Commit cf53630

Browse files
committed
Support passing the managed folder as an id or a Folder in setup_mlflow
1 parent dc5e5dd commit cf53630

File tree

2 files changed

+20
-9
lines changed

2 files changed

+20
-9
lines changed

dataikuapi/dss/project.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1612,8 +1612,10 @@ def setup_mlflow(self, managed_folder, host=None):
16121612
"""
16131613
Setup the dss-plugin for MLflow
16141614
1615-
:param object managed_folder: a :class:`dataikuapi.dss.DSSManagedFolder` where MLflow artifacts should be stored.
1616-
:param str host: setup a custom host if the backend used is not DSS
1615+
:param object managed_folder: the managed folder where MLflow artifacts should be stored.
1616+
Can be either a managed folder id as a string,
1617+
a :class:`dataikuapi.dss.DSSManagedFolder`, or a :class:`dataiku.Folder`
1618+
:param str host: setup a custom host if the backend used is not DSS.
16171619
"""
16181620
return MLflowHandle(client=self.client, project=self, managed_folder=managed_folder, host=host)
16191621

dataikuapi/dss_plugin_mlflow/utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -62,30 +62,39 @@ def __init__(self, client, project, managed_folder, host=None):
6262
"DSS_MLFLOW_INTERNAL_TICKET": self.client.internal_ticket
6363
})
6464

65-
if not isinstance(managed_folder, DSSManagedFolder):
66-
raise TypeError('managed_folder must a DSSManagedFolder.')
67-
6865
if not client._session.verify:
6966
self.mlflow_env.update({"MLFLOW_TRACKING_INSECURE_TLS": "true"})
7067
elif isinstance(client._session.verify, str):
7168
self.mlflow_env.update({"MLFLOW_TRACKING_SERVER_CERT_PATH": client._session.verify})
7269

70+
mf_full_id = managed_folder
71+
if isinstance(managed_folder, DSSManagedFolder):
72+
mf_full_id = managed_folder.project.project_key + "." + managed_folder.id
73+
elif hasattr(managed_folder, 'name'): # True if dataiku.Folder
74+
mf_full_id = managed_folder.name
75+
76+
if not isinstance(mf_full_id, str):
77+
raise TypeError('Type of managed_folder must be "str", "DSSManagedFolder" or "dataiku.Folder".')
78+
79+
if not "." in mf_full_id:
80+
mf_full_id = self.project_key + "." + mf_full_id
81+
82+
mf_project = mf_full_id.split(".")[0]
83+
mf_id = mf_full_id.split(".")[1]
7384

74-
mf_project = managed_folder.project.project_key
75-
mf_id = managed_folder.id
7685
try:
7786
client.get_project(mf_project).get_managed_folder(mf_id).get_definition()
7887
except DataikuException as e:
7988
if "NotFoundException" in str(e):
80-
logging.error('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_id))
89+
logging.error('The managed folder "%s" does not exist, please create it in your project flow before running this command.' % (mf_full_id))
8190
raise
8291

8392
# Set host, tracking URI, project key and managed_folder_id
8493
self.mlflow_env.update({
8594
"DSS_MLFLOW_PROJECTKEY": self.project_key,
8695
"MLFLOW_TRACKING_URI": self.client.host + "/dip/publicapi" if host is None else host,
8796
"DSS_MLFLOW_HOST": self.client.host,
88-
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_project + "." + mf_id
97+
"DSS_MLFLOW_MANAGED_FOLDER_ID": mf_full_id
8998
})
9099

91100
os.environ.update(self.mlflow_env)

0 commit comments

Comments
 (0)