|
1 | 1 | import numbers |
| 2 | +from functools import partial |
2 | 3 |
|
3 | 4 | import numpy as np |
4 | 5 | from numpy.random import RandomState |
| 6 | +from scipy.stats import ttest_1samp, wilcoxon |
| 7 | + |
| 8 | +from hidimstat.statistical_tools.nadeau_bengio_ttest import nadeau_bengio_ttest |
5 | 9 |
|
6 | 10 |
|
7 | 11 | def _check_vim_predict_method(method): |
@@ -33,6 +37,37 @@ def _check_vim_predict_method(method): |
33 | 37 | ) |
34 | 38 |
|
35 | 39 |
|
| 40 | +def get_fitted_attributes(cls): |
| 41 | + """ |
| 42 | + Get all attributes from a class that end with a single underscore |
| 43 | + and doesn't start with one underscore. |
| 44 | +
|
| 45 | + Parameters |
| 46 | + ---------- |
| 47 | + cls : class |
| 48 | + The class to inspect for attributes. |
| 49 | +
|
| 50 | + Returns |
| 51 | + ------- |
| 52 | + list |
| 53 | + A list of attribute names that end with a single underscore but not double underscore. |
| 54 | + """ |
| 55 | + # Get all attributes and methods of the class |
| 56 | + all_attributes = dir(cls) |
| 57 | + |
| 58 | + # Filter out attributes that start with an underscore |
| 59 | + filtered_attributes = [attr for attr in all_attributes if not attr.startswith("_")] |
| 60 | + |
| 61 | + # Filter out attributes that do not end with a single underscore |
| 62 | + result = [ |
| 63 | + attr |
| 64 | + for attr in filtered_attributes |
| 65 | + if attr.endswith("_") and not attr.endswith("__") |
| 66 | + ] |
| 67 | + |
| 68 | + return result |
| 69 | + |
| 70 | + |
36 | 71 | def check_random_state(seed): |
37 | 72 | """ |
38 | 73 | Modified version of sklearn's check_random_state using np.random.Generator. |
@@ -105,3 +140,56 @@ def seed_estimator(estimator, random_state=None): |
105 | 140 | setattr(value, "random_state", RandomState(rng.bit_generator)) |
106 | 141 |
|
107 | 142 | return estimator |
| 143 | + |
| 144 | + |
| 145 | +def check_statistical_test(statistical_test, test_frac=None): |
| 146 | + """ |
| 147 | + Validates and returns a test statistic function. |
| 148 | +
|
| 149 | + Parameters |
| 150 | + ---------- |
| 151 | + statisticcal_test : str or callable |
| 152 | + If str, must be either 'ttest' or 'wilcoxon'. |
| 153 | + If callable, must be a function that can be used as a test statistic. |
| 154 | + test_frac : float, optional |
| 155 | + The fraction of data used for testing in the Nadeau-Bengio t-test. |
| 156 | +
|
| 157 | + Returns |
| 158 | + ------- |
| 159 | + callable |
| 160 | + A function that can be used as a test statistic. |
| 161 | + For string inputs, returns a partial function of either ttest_1samp or wilcoxon. |
| 162 | + For callable inputs, returns the input function. |
| 163 | +
|
| 164 | + Raises |
| 165 | + ------ |
| 166 | + ValueError |
| 167 | + If test is a string but not one of the supported test names ('ttest' or 'wilcoxon'). |
| 168 | + ValueError |
| 169 | + If test is neither a string nor a callable. |
| 170 | + """ |
| 171 | + if isinstance(statistical_test, str): |
| 172 | + if statistical_test == "ttest": |
| 173 | + return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) |
| 174 | + elif statistical_test == "wilcoxon": |
| 175 | + return partial(wilcoxon, alternative="greater", axis=1) |
| 176 | + elif statistical_test == "nb-ttest": |
| 177 | + return partial( |
| 178 | + nadeau_bengio_ttest, |
| 179 | + popmean=0, |
| 180 | + test_frac=test_frac, |
| 181 | + alternative="greater", |
| 182 | + axis=1, |
| 183 | + ) |
| 184 | + else: |
| 185 | + raise ValueError(f"the test '{statistical_test}' is not supported") |
| 186 | + elif callable(statistical_test): |
| 187 | + return statistical_test |
| 188 | + else: |
| 189 | + raise ValueError( |
| 190 | + f"Unsupported value for 'statistical_test'." |
| 191 | + f"The provided argument was '{statistical_test}'. " |
| 192 | + f"Please choose from the following valid options: " |
| 193 | + f"string values ('ttest', 'wilcoxon', 'nb-ttest') " |
| 194 | + f"or a custom callable function with a `scipy.stats` API-compatible signature." |
| 195 | + ) |
0 commit comments