|
6 | 6 | import warnings |
7 | 7 |
|
8 | 8 | import numpy as np |
| 9 | +from tqdm import tqdm |
| 10 | +from itertools import combinations |
9 | 11 | from scipy import optimize, spatial, special, stats as sstats |
10 | 12 | from scipy.stats.stats import _chk2_asarray |
11 | 13 | from sklearn.utils.validation import check_random_state |
| 14 | +from sklearn.linear_model import LinearRegression |
12 | 15 |
|
13 | 16 | from . import utils |
14 | 17 |
|
@@ -784,3 +787,141 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True, |
784 | 787 | return spinsamples, cost |
785 | 788 |
|
786 | 789 | return spinsamples |
| 790 | + |
| 791 | + |
def get_dominance_stats(X, y, use_adjusted_r_sq=True, verbose=False):
    """
    Returns the dominance analysis statistics for multilinear regression.

    This is a rewritten & simplified version of [DA1]_. It is briefly
    tested against the original package, but still in early stages.
    Please feel free to report any bugs.

    Warning: Still work-in-progress. Parameters might change!

    Parameters
    ----------
    X : (N, M) array_like
        Input data
    y : (N,) array_like
        Target values
    use_adjusted_r_sq : bool, optional
        Whether to use adjusted r squares. Default: True
    verbose : bool, optional
        Whether to print debug messages. Default: False

    Returns
    -------
    model_metrics : dict
        The dominance metrics, currently containing `individual_dominance`,
        `partial_dominance`, `total_dominance`, and `full_r_sq`.
    model_r_sq : dict
        Contains all model r squares

    Notes
    -----
    Example usage

    .. code:: python

        from netneurotools.stats import get_dominance_stats
        from sklearn.datasets import load_boston
        X, y = load_boston(return_X_y=True)
        model_metrics, model_r_sq = get_dominance_stats(X, y)

    To compare with [DA1]_, use `use_adjusted_r_sq=False`

    .. code:: python

        from dominance_analysis import Dominance_Datasets
        from dominance_analysis import Dominance
        boston_dataset=Dominance_Datasets.get_boston()
        dominance_regression=Dominance(data=boston_dataset,
                                       target='House_Price',objective=1)
        incr_variable_rsquare=dominance_regression.incremental_rsquare()
        dominance_regression.dominance_stats()

    References
    ----------
    .. [DA1] https://github.com/dominance-analysis/dominance-analysis

    """
    X = np.asarray(X)
    y = np.asarray(y)

    # helper: a copy of tuple `tpl` with one occurrence of `elem` removed
    def remove_ret(tpl, elem):
        lst = list(tpl)
        lst.remove(elem)
        return tuple(lst)

    # OLS r^2 via least squares on an intercept-augmented design matrix.
    # Equivalent to sklearn's LinearRegression but avoids constructing a
    # fresh estimator for every one of the 2^M - 1 predictor subsets.
    def get_reg_r_sq(X_sub, y):
        n_obs, n_feat = X_sub.shape
        design = np.column_stack([np.ones(n_obs), X_sub])
        coef, *_ = np.linalg.lstsq(design, y, rcond=None)
        resid = y - design @ coef
        ss_residual = float(np.sum(resid ** 2))
        ss_total = float(np.sum((y - np.mean(y)) ** 2))
        r_squared = 1 - ss_residual / ss_total
        if not use_adjusted_r_sq:
            # computed lazily so a saturated model (n_obs == n_feat + 1)
            # does not raise ZeroDivisionError when plain r^2 was requested
            return r_squared
        return 1 - (1 - r_squared) * (n_obs - 1) / (n_obs - n_feat - 1)

    # generate all predictor combinations in list (num of predictors) of lists
    n_predictor = X.shape[-1]
    predictor_combs = [list(combinations(range(n_predictor), i))
                       for i in range(1, n_predictor + 1)]
    if verbose:
        n_models = sum(len(group) for group in predictor_combs)
        print(f"[Dominance analysis] Generated {n_models} combinations")

    # get all r_sq's; tqdm progress bars only when verbose (the name is
    # resolved lazily so the quiet path has no tqdm dependency)
    model_r_sq = dict()
    outer_iter = (tqdm(predictor_combs, desc='num-of-predictor loop')
                  if verbose else predictor_combs)
    for len_group in outer_iter:
        inner_iter = (tqdm(len_group, desc='insider loop')
                      if verbose else len_group)
        for idx_tuple in inner_iter:
            model_r_sq[idx_tuple] = get_reg_r_sq(X[:, idx_tuple], y)
    if verbose:
        print(f"[Dominance analysis] Acquired {len(model_r_sq)} r^2's")

    # getting all model metrics
    model_metrics = dict()

    # individual dominance: r^2 of each single-predictor model
    individual_dominance = np.array(
        [model_r_sq[(i_pred,)] for i_pred in range(n_predictor)]
    ).reshape(1, -1)
    model_metrics["individual_dominance"] = individual_dominance

    # partial dominance: for each subset size k >= 2, the mean r^2 increment
    # a predictor contributes to every size-k model containing it.
    # NOTE: independent inner lists are required here -- `[[]] * n` would
    # alias a single shared list and corrupt the rows for n_predictor > 2
    partial_dominance = [[] for _ in range(n_predictor - 1)]
    for i_len in range(n_predictor - 1):
        i_len_combs = list(combinations(range(n_predictor), i_len + 2))
        for j_node in range(n_predictor):
            j_node_sel = [comb for comb in i_len_combs if j_node in comb]
            diff_values = [
                model_r_sq[comb] - model_r_sq[remove_ret(comb, j_node)]
                for comb in j_node_sel]
            partial_dominance[i_len].append(np.mean(diff_values))

    # save partial dominance; reshape keeps a (0, n_predictor) shape for the
    # single-predictor edge case so np.r_ below still concatenates cleanly
    partial_dominance = np.asarray(
        partial_dominance, dtype=float).reshape(-1, n_predictor)
    model_metrics["partial_dominance"] = partial_dominance
    # get total dominance: average contribution across all subset sizes
    total_dominance = np.mean(
        np.r_[individual_dominance, partial_dominance], axis=0)
    # sanity check: dominance weights must decompose the full-model r^2
    assert np.allclose(total_dominance.sum(),
                       model_r_sq[tuple(range(n_predictor))]), \
        "Sum of total dominance is not equal to full r square!"
    model_metrics["total_dominance"] = total_dominance
    # save full r^2
    model_metrics["full_r_sq"] = model_r_sq[tuple(range(n_predictor))]

    return model_metrics, model_r_sq
0 commit comments