Skip to content

Commit 3ba0abd

Browse files
authored
Merge pull request #100 from dataiku/feature/dss90-hints-api
diagnostics: add helper in DSSTrainedModelDetails
2 parents b09f950 + c17796a commit 3ba0abd

File tree

1 file changed

+100
-0
lines changed

1 file changed

+100
-0
lines changed

dataikuapi/dss/ml.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,50 @@ def get_algorithm_settings(self, algorithm_name):
253253

254254
return self.mltask_settings["modeling"][algorithm_name.lower()]
255255

256+
def get_diagnostics_settings(self):
257+
"""
258+
Gets the diagnostics settings for a mltask. This returns a reference to the
259+
diagnostics' settings, not a copy, so changes made to the returned object will be reflected when saving.
260+
261+
This method returns a dictionary of the settings with:
262+
- 'enabled': indicates if the diagnostics are enabled globally, if False, all diagnostics will be disabled
263+
- 'settings': a list of dict comprised of:
264+
- 'type': the diagnostic type
265+
- 'enabled': indicates if the diagnostic type is enabled, if False, all diagnostics of that type will be disabled
266+
267+
Please refer to the documentation for details on available diagnostics.
268+
269+
:return: A dict of diagnostics settings
270+
:rtype: dict
271+
"""
272+
return self.mltask_settings["diagnosticsSettings"]
273+
274+
def set_diagnostics_enabled(self, enabled):
275+
"""
276+
Globally enables or disables all diagnostics.
277+
278+
:param bool enabled: if the diagnostics should be enabled or not
279+
"""
280+
settings = self.get_diagnostics_settings()
281+
settings["enabled"] = enabled
282+
283+
def set_diagnostic_type_enabled(self, diagnostic_type, enabled):
284+
"""
285+
Enables or disables a diagnostic based on its type.
286+
287+
Please refer to the documentation for details on available diagnostics.
288+
289+
:param str diagnostic_type: Name (in capitals) of the diagnostic type.
290+
:param bool enabled: if the diagnostic should be enabled or not
291+
"""
292+
settings = self.get_diagnostics_settings()["settings"]
293+
diagnostic = [h for h in settings if h["type"] == diagnostic_type]
294+
if len(diagnostic) == 0:
295+
raise ValueError("Diagnostic type '{}' not found in settings".format(diagnostic_type))
296+
if len(diagnostic) > 1:
297+
raise ValueError("Should not happen: multiple diagnostic types '{}' found in settings".format(diagnostic_type))
298+
diagnostic[0]["enabled"] = enabled
299+
256300
def set_algorithm_enabled(self, algorithm_name, enabled):
257301
"""
258302
Enables or disables an algorithm based on its name.
@@ -543,6 +587,62 @@ def get_origin_analysis_trained_model(self):
543587
project_key=self.saved_model.project_key)
544588
return origin_ml_task.get_trained_model_details(fmi)
545589

590+
def get_diagnostics(self):
591+
"""
592+
Retrieves diagnostics computed for this trained model
593+
594+
:returns: list of diagnostics
595+
:rtype: list of type `dataikuapi.dss.ml.DSSMLDiagnostic`
596+
"""
597+
diagnostics = self.details.get("trainDiagnostics", {})
598+
return [DSSMLDiagnostic(d) for d in diagnostics.get("diagnostics", [])]
599+
600+
601+
class DSSMLDiagnostic(object):
602+
"""
603+
Object that represents a computed Diagnostic on a trained model
604+
605+
Do not create this object directly, use :meth:`DSSTrainedModelDetails.get_diagnostics()` instead
606+
"""
607+
608+
def __init__(self, data):
609+
self._internal_dict = data
610+
611+
def get_raw(self):
612+
"""
613+
Gets the raw dictionary of the diagnostic
614+
615+
:rtype: dict
616+
"""
617+
return self._internal_dict
618+
619+
def get_type(self):
620+
"""
621+
Returns the base Diagnostic type
622+
:rtype: str
623+
"""
624+
return self._internal_dict["type"]
625+
626+
def get_type_pretty(self):
627+
"""
628+
Returns the Diagnostic type as displayed in the UI
629+
:rtype: str
630+
"""
631+
return self._internal_dict["displayableType"]
632+
633+
def get_message(self):
634+
"""
635+
Returns the message as displayed in the UI
636+
:rtype: str
637+
"""
638+
return self._internal_dict["message"]
639+
640+
def __repr__(self):
641+
return "{cls}(type={type}, message={msg})".format(cls=self.__class__.__name__,
642+
type=self._internal_dict["type"],
643+
msg=self._internal_dict["message"])
644+
645+
546646
class DSSTreeNode(object):
547647
def __init__(self, tree, i):
548648
self.tree = tree

0 commit comments

Comments
 (0)