-
Notifications
You must be signed in to change notification settings - Fork 12
[API 2]: CFI, PFI, LOCO #372
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 79 commits
df93c78
ccb60ed
7c827ad
82d61e6
28593e4
cabfb63
7d7fd7d
b958cc7
1f97d60
db96bb6
d656f17
0493b6f
9c54e1b
7bf75e4
b3cd78a
7825490
084ad24
7379ec1
02ae5ba
1e91c65
46b8fa5
58a57f8
03b919a
c4ea731
43d3f99
5fd99b0
3c52789
c93f14c
aa583d5
01cbc44
b1c5f40
1a47330
1ef69a6
d00566b
83ae849
a3cd681
b3f336a
a364a93
6dc8d67
afd03cb
3fa9d01
1202426
75d6578
ae3dfa9
035f4c8
84ff550
251584c
1c15c5b
9414368
8a8587f
631ee83
39ebce1
971da61
c7adec9
a3c7906
f4c8ce4
c983eab
f91ac6c
38c6bc7
0bb880e
a92ea83
2166f2c
4cc3598
0e6b929
4389796
9a640b1
431d9a6
e7bbb30
e003345
1973bcc
57ca067
e4d686c
a9cd709
75e81d7
9ab658a
329bf43
9f488d6
17f8d6e
d33205c
cfc12d0
7a4f44a
d14b835
87e2029
c61bb44
b0c4ec0
359118a
1deae93
911f11c
bc4ee65
e93b97f
e03266e
eca0c0c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,7 +1,11 @@ | ||
| import numbers | ||
| from functools import partial | ||
|
|
||
| import numpy as np | ||
| from numpy.random import RandomState | ||
| from scipy.stats import ttest_1samp, wilcoxon | ||
|
|
||
| from hidimstat.statistical_tools.nadeau_bengio_ttest import nadeau_bengio_ttest | ||
|
|
||
|
|
||
| def _check_vim_predict_method(method): | ||
|
|
@@ -33,6 +37,37 @@ def _check_vim_predict_method(method): | |
| ) | ||
|
|
||
|
|
||
| def get_fitted_attributes(cls): | ||
| """ | ||
| Get all attributes from a class that end with a single underscore | ||
| and doesn't start with one underscore. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| cls : class | ||
| The class to inspect for attributes. | ||
|
|
||
| Returns | ||
| ------- | ||
| list | ||
| A list of attribute names that end with a single underscore but not double underscore. | ||
| """ | ||
| # Get all attributes and methods of the class | ||
| all_attributes = dir(cls) | ||
|
|
||
| # Filter out attributes that start with an underscore | ||
| filtered_attributes = [attr for attr in all_attributes if not attr.startswith("_")] | ||
|
|
||
| # Filter out attributes that do not end with a single underscore | ||
| result = [ | ||
| attr | ||
| for attr in filtered_attributes | ||
| if attr.endswith("_") and not attr.endswith("__") | ||
| ] | ||
|
|
||
| return result | ||
|
|
||
|
|
||
| def check_random_state(seed): | ||
| """ | ||
| Modified version of sklearn's check_random_state using np.random.Generator. | ||
|
|
@@ -105,3 +140,54 @@ def seed_estimator(estimator, random_state=None): | |
| setattr(value, "random_state", RandomState(rng.bit_generator)) | ||
|
|
||
| return estimator | ||
|
|
||
|
|
||
| def check_statistical_test(statistical_test): | ||
| """ | ||
| Validates and returns a test statistic function. | ||
|
|
||
| Parameters | ||
| ---------- | ||
| test : str or callable | ||
| If str, must be either 'ttest' or 'wilcoxon'. | ||
jpaillard marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| If callable, must be a function that can be used as a test statistic. | ||
|
|
||
| Returns | ||
| ------- | ||
| callable | ||
| A function that can be used as a test statistic. | ||
| For string inputs, returns a partial function of either ttest_1samp or wilcoxon. | ||
| For callable inputs, returns the input function. | ||
|
|
||
| Raises | ||
| ------ | ||
| ValueError | ||
| If test is a string but not one of the supported test names ('ttest' or 'wilcoxon'). | ||
| ValueError | ||
| If test is neither a string nor a callable. | ||
| """ | ||
| if isinstance(statistical_test, str): | ||
| if statistical_test == "ttest": | ||
| return partial(ttest_1samp, popmean=0, alternative="greater", axis=1) | ||
| elif statistical_test == "wilcoxon": | ||
| return partial(wilcoxon, alternative="greater", axis=1) | ||
| elif statistical_test == "nb-ttest": | ||
| return partial( | ||
| nadeau_bengio_ttest, | ||
| popmean=0, | ||
| test_frac=0.1 / 0.9, | ||
jpaillard marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| alternative="greater", | ||
| axis=1, | ||
| ) | ||
| else: | ||
| raise ValueError(f"the test '{statistical_test}' is not supported") | ||
| elif callable(statistical_test): | ||
| return statistical_test | ||
| else: | ||
| raise ValueError( | ||
| f"Unsupported value for 'statistical_test'." | ||
| f"The provided argument was '{statistical_test}'. " | ||
| f"Please choose from the following valid options: " | ||
| f"string values ('ttest', 'wilcoxon', 'nb-ttest') " | ||
| f"or a custom callable function with a `scipy.stats` API-compatible signature." | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.