|
74 | 74 | d0crt_linear.fit_importance(X, y) |
75 | 75 | pval_dcrt_linear = d0crt_linear.pvalues_ |
76 | 76 |
|
77 | | -d0crt_non_linear = D0CRT(estimator=clone(non_linear_model), screening_threshold=None) |
| 77 | +d0crt_non_linear = D0CRT( |
| 78 | + estimator=clone(non_linear_model), |
| 79 | + screening_threshold=None, |
| 80 | +) |
78 | 81 | d0crt_non_linear.fit_importance(X, y) |
79 | 82 | pval_dcrt_non_linear = d0crt_non_linear.pvalues_ |
80 | 83 |
|
|
97 | 100 | linear_model_.fit(X[train], y[train]) |
98 | 101 |
|
99 | 102 | vim_linear = LOCO( |
100 | | - estimator=linear_model_, loss=log_loss, method="predict_proba", n_jobs=2 |
| 103 | + estimator=linear_model_, |
| 104 | + loss=log_loss, |
| 105 | + method="predict_proba", |
| 106 | + n_jobs=2, |
101 | 107 | ) |
102 | 108 | vim_non_linear = LOCO( |
103 | 109 | estimator=non_linear_model_, |
|
108 | 114 | vim_linear.fit(X[train], y[train]) |
109 | 115 | vim_non_linear.fit(X[train], y[train]) |
110 | 116 |
|
111 | | - importances_linear.append(vim_linear.importance(X[test], y[test])["importance"]) |
| 117 | + importances_linear.append( |
| 118 | + vim_linear.importance(X[test], y[test])["importance"], |
| 119 | + ) |
112 | 120 | importances_non_linear.append( |
113 | 121 | vim_non_linear.importance(X[test], y[test])["importance"] |
114 | 122 | ) |
|
118 | 126 | # To select variables using LOCO, we compute the p-values using a t-test over the |
119 | 127 | # importance scores. |
120 | 128 |
|
121 | | -_, pval_linear = ttest_1samp(importances_linear, 0, axis=0, alternative="greater") |
| 129 | +_, pval_linear = ttest_1samp( |
| 130 | + importances_linear, |
| 131 | + 0, |
| 132 | + axis=0, |
| 133 | + alternative="greater", |
| 134 | +) |
122 | 135 | _, pval_non_linear = ttest_1samp( |
123 | 136 | importances_non_linear, 0, axis=0, alternative="greater" |
124 | 137 | ) |
125 | 138 |
|
126 | 139 | df_pval = pd.DataFrame( |
127 | 140 | { |
128 | 141 | "pval": np.hstack( |
129 | | - [pval_dcrt_linear, pval_dcrt_non_linear, pval_linear, pval_non_linear] |
| 142 | + [ |
| 143 | + pval_dcrt_linear, |
| 144 | + pval_dcrt_non_linear, |
| 145 | + pval_linear, |
| 146 | + pval_non_linear, |
| 147 | + ] |
130 | 148 | ), |
131 | 149 | "method": ["d0CRT-linear"] * 2 |
132 | 150 | + ["d0CRT-non-linear"] * 2 |
|
152 | 170 | ) |
153 | 171 | ax.set_xlabel("-$\\log_{10}(pval)$") |
154 | 172 | ax.axvline( |
155 | | - -np.log10(0.05), color="k", lw=3, linestyle="--", label="-$\\log_{10}(0.05)$" |
| 173 | + -np.log10(0.05), |
| 174 | + color="k", |
| 175 | + lw=3, |
| 176 | + linestyle="--", |
| 177 | + label="-$\\log_{10}(0.05)$", |
156 | 178 | ) |
157 | 179 | ax.legend() |
158 | 180 | plt.show() |
|
0 commit comments