From 8abc2374c2e5f8c6ef8d0120a51cbf5fb8376889 Mon Sep 17 00:00:00 2001 From: RozeQz Date: Fri, 25 Oct 2024 12:42:50 +0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20=20Remove=20usage=20of=20interna?= =?UTF-8?q?l=20sklearn.utils=20api,=20replace=20with=20equivalent?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- sklift/utils/__init__.py | 4 ++-- sklift/utils/utils.py | 17 +++++++++++++++++ sklift/viz/base.py | 3 +-- 3 files changed, 20 insertions(+), 4 deletions(-) diff --git a/sklift/utils/__init__.py b/sklift/utils/__init__.py index 5cebd97..bbdab5b 100644 --- a/sklift/utils/__init__.py +++ b/sklift/utils/__init__.py @@ -1,3 +1,3 @@ -from .utils import check_is_binary +from .utils import check_is_binary, check_matplotlib_support -__all__ = ['check_is_binary'] \ No newline at end of file +__all__ = ['check_is_binary', 'check_matplotlib_support'] diff --git a/sklift/utils/utils.py b/sklift/utils/utils.py index 9aa2690..c43a75d 100644 --- a/sklift/utils/utils.py +++ b/sklift/utils/utils.py @@ -11,3 +11,20 @@ def check_is_binary(array): raise ValueError(f"Input array is not binary. " f"Array should contain only int or float binary values 0 (or 0.) and 1 (or 1.). " f"Got values {np.unique(array)}.") + + +def check_matplotlib_support(caller_name): + """Raise ImportError with detailed error message if mpl is not installed. + Plot utilities like any of the Display's plotting functions should lazily import + matplotlib and call this helper before any computation. + + Args: + caller_name (str): The name of the caller that requires matplotlib. + """ + try: + import matplotlib # noqa + except ImportError as e: + raise ImportError( + "{} requires matplotlib. You can install matplotlib with " + "`pip install matplotlib`".format(caller_name) + ) from e diff --git a/sklift/viz/base.py b/sklift/viz/base.py index 75bf0dd..3a110f5 100644 --- a/sklift/viz/base.py +++ b/sklift/viz/base.py @@ -1,9 +1,8 @@ import matplotlib.pyplot as plt import numpy as np from sklearn.utils.validation import check_consistent_length -from sklearn.utils import check_matplotlib_support -from ..utils import check_is_binary +from ..utils import check_is_binary, check_matplotlib_support from ..metrics import ( uplift_curve, perfect_uplift_curve, uplift_auc_score, qini_curve, perfect_qini_curve, qini_auc_score,