Skip to content

Commit 8998e60

Browse files
committed
Refactor check_fitted utility
1 parent a15ce1b commit 8998e60

File tree

3 files changed

+35
-11
lines changed

3 files changed

+35
-11
lines changed

ylearn/uplift/_model.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,11 @@
11
import numpy as np
22
import pandas as pd
33

4+
from ylearn.utils import check_fitted as _check_fitted
45
from ._metric import get_gain, get_qini, get_cumlift, auuc_score, qini_score
56
from ._plot import plot_gain, plot_qini, plot_cumlift
67

78

8-
def _check_fitted(fn):
9-
def _exec(obj, *args, **kwargs):
10-
assert isinstance(obj, UpliftModel)
11-
if obj.lift_ is None:
12-
raise ValueError(f'fit {type(obj).__name__} before call {fn.__name__}() please.')
13-
14-
return fn(obj, *args, **kwargs)
15-
16-
return _exec
17-
18-
199
class UpliftModel(object):
2010
def __init__(self):
2111
# fitted
@@ -57,6 +47,7 @@ def fit(self, df_lift, outcome='y', treatment='x', true_effect=None, treat=1, co
5747
)
5848
self.random_ = random
5949
self.lift_ = df_lift.copy()
50+
self._is_fitted = True
6051

6152
return self
6253

ylearn/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@
1313
from ._common import to_df, context, is_notebook, view_pydot
1414
from ._common import to_list, join_list
1515
from ._common import to_snake_case, to_camel_case, drop_none
16+
from ._common import check_fitted, check_fitted_

ylearn/utils/_common.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import inspect
33
import warnings
44
from collections import OrderedDict
5+
from functools import partial
56

67
import numpy as np
78
import pandas as pd
@@ -262,3 +263,34 @@ def _get_array(cols):
262263
S = map(_get_array, S)
263264

264265
return tuple(S)
266+
267+
268+
def check_fitted(fn, attr_name='_is_fitted', check=None, msg=None):
269+
assert callable(fn)
270+
271+
sig = inspect.signature(fn)
272+
assert 'self' in sig.parameters.keys()
273+
274+
def check_and_call(obj, *args, **kwargs):
275+
if callable(check):
276+
fitted = check(obj)
277+
else:
278+
fitted_tag = getattr(obj, attr_name, None)
279+
if check is not None:
280+
fitted = not (fitted_tag is check)
281+
else:
282+
fitted = fitted_tag
283+
284+
if not fitted:
285+
if msg is not None:
286+
raise ValueError(msg)
287+
else:
288+
raise ValueError(f'{type(obj).__name__} is not fitted.')
289+
290+
return fn(obj, *args, **kwargs)
291+
292+
return check_and_call
293+
294+
295+
def check_fitted_(attr_name='_is_fitted', check=None, msg=None):
296+
return partial(check_fitted, attr_name=attr_name, check=check, msg=msg)

0 commit comments

Comments
 (0)