diff --git a/sklift/utils/__init__.py b/sklift/utils/__init__.py index 5cebd97..bbdab5b 100644 --- a/sklift/utils/__init__.py +++ b/sklift/utils/__init__.py @@ -1,3 +1,3 @@ -from .utils import check_is_binary +from .utils import check_is_binary, check_matplotlib_support -__all__ = ['check_is_binary'] \ No newline at end of file +__all__ = ['check_is_binary', 'check_matplotlib_support'] diff --git a/sklift/utils/utils.py b/sklift/utils/utils.py index 9aa2690..c43a75d 100644 --- a/sklift/utils/utils.py +++ b/sklift/utils/utils.py @@ -11,3 +11,20 @@ def check_is_binary(array): raise ValueError(f"Input array is not binary. " f"Array should contain only int or float binary values 0 (or 0.) and 1 (or 1.). " f"Got values {np.unique(array)}.") + + +def check_matplotlib_support(caller_name): + """Raise ImportError with detailed error message if mpl is not installed. + Plot utilities like any of the Display's plotting functions should lazily import + matplotlib and call this helper before any computation. + + Args: + caller_name (str): The name of the caller that requires matplotlib. + """ + try: + import matplotlib # noqa + except ImportError as e: + raise ImportError( + "{} requires matplotlib. You can install matplotlib with " + "`pip install matplotlib`".format(caller_name) + ) from e diff --git a/sklift/viz/base.py b/sklift/viz/base.py index 75bf0dd..3a110f5 100644 --- a/sklift/viz/base.py +++ b/sklift/viz/base.py @@ -1,9 +1,8 @@ import matplotlib.pyplot as plt import numpy as np from sklearn.utils.validation import check_consistent_length -from sklearn.utils import check_matplotlib_support -from ..utils import check_is_binary +from ..utils import check_is_binary, check_matplotlib_support from ..metrics import ( uplift_curve, perfect_uplift_curve, uplift_auc_score, qini_curve, perfect_qini_curve, qini_auc_score,