Commit 264c77d

Merge pull request #97 from liuzhenqi77/dominance-stats
[ENH] Adds dominance stats function
2 parents a1beaed + ea3903d commit 264c77d

3 files changed (+143, -0)


docs/api.rst

Lines changed: 1 addition & 0 deletions

@@ -111,6 +111,7 @@ Python Reference API
    permtest_1samp
    permtest_rel
    permtest_pearsonr
+   get_dominance_stats

 .. _ref_metrics:

netneurotools/stats.py

Lines changed: 141 additions & 0 deletions
@@ -6,9 +6,12 @@
 import warnings
 
 import numpy as np
+from tqdm import tqdm
+from itertools import combinations
 from scipy import optimize, spatial, special, stats as sstats
 from scipy.stats.stats import _chk2_asarray
 from sklearn.utils.validation import check_random_state
+from sklearn.linear_model import LinearRegression
 
 from . import utils
 
@@ -784,3 +787,141 @@ def gen_spinsamples(coords, hemiid, n_rotate=1000, check_duplicates=True,
         return spinsamples, cost
 
     return spinsamples
+
+
+def get_dominance_stats(X, y, use_adjusted_r_sq=True, verbose=False):
+    """
+    Returns the dominance analysis statistics for multilinear regression.
+
+    This is a rewritten & simplified version of [DA1]_. It has been briefly
+    tested against the original package but is still in its early stages.
+    Please feel free to report any bugs.
+
+    Warning: Still work-in-progress. Parameters might change!
+
+    Parameters
+    ----------
+    X : (N, M) array_like
+        Input data
+    y : (N,) array_like
+        Target values
+    use_adjusted_r_sq : bool, optional
+        Whether to use the adjusted R-squared. Default: True
+    verbose : bool, optional
+        Whether to print debug messages. Default: False
+
+    Returns
+    -------
+    model_metrics : dict
+        The dominance metrics, currently containing `individual_dominance`,
+        `partial_dominance`, `total_dominance`, and `full_r_sq`.
+    model_r_sq : dict
+        The R-squared of every fitted sub-model, keyed by the tuple of
+        predictor indices.
+
+    Notes
+    -----
+    Example usage
+
+    .. code:: python
+
+        from netneurotools.stats import get_dominance_stats
+        from sklearn.datasets import load_boston
+        X, y = load_boston(return_X_y=True)
+        model_metrics, model_r_sq = get_dominance_stats(X, y)
+
+    To compare with [DA1]_, use `use_adjusted_r_sq=False`:
+
+    .. code:: python
+
+        from dominance_analysis import Dominance_Datasets
+        from dominance_analysis import Dominance
+        boston_dataset = Dominance_Datasets.get_boston()
+        dominance_regression = Dominance(data=boston_dataset,
+                                         target='House_Price', objective=1)
+        incr_variable_rsquare = dominance_regression.incremental_rsquare()
+        dominance_regression.dominance_stats()
+
+    References
+    ----------
+    .. [DA1] https://github.com/dominance-analysis/dominance-analysis
+
+    """
+
+    # helper that removes one element from a tuple
+    def remove_ret(tpl, elem):
+        lst = list(tpl)
+        lst.remove(elem)
+        return tuple(lst)
+
+    # sklearn linear regression wrapper returning the (adjusted) R-squared
+    def get_reg_r_sq(X, y):
+        lin_reg = LinearRegression()
+        lin_reg.fit(X, y)
+        yhat = lin_reg.predict(X)
+        SS_Residual = sum((y - yhat) ** 2)
+        SS_Total = sum((y - np.mean(y)) ** 2)
+        r_squared = 1 - (float(SS_Residual)) / SS_Total
+        adjusted_r_squared = 1 - (1 - r_squared) * \
+            (len(y) - 1) / (len(y) - X.shape[1] - 1)
+        if use_adjusted_r_sq:
+            return adjusted_r_squared
+        else:
+            return r_squared
+
+    # generate all predictor combinations, grouped by number of predictors
+    n_predictor = X.shape[-1]
+    predictor_combs = [list(combinations(range(n_predictor), i))
+                       for i in range(1, n_predictor + 1)]
+    if verbose:
+        print("[Dominance analysis] Generated "
+              f"{len([v for i in predictor_combs for v in i])} combinations")
+
+    # fit every sub-model and store its R-squared, keyed by predictor tuple
+    model_r_sq = dict()
+    for len_group in tqdm(predictor_combs, desc='num-of-predictor loop',
+                          disable=not verbose):
+        for idx_tuple in tqdm(len_group, desc='inner loop',
+                              disable=not verbose):
+            r_sq = get_reg_r_sq(X[:, idx_tuple], y)
+            model_r_sq[idx_tuple] = r_sq
+    if verbose:
+        print(f"[Dominance analysis] Acquired {len(model_r_sq)} r^2's")
+
+    # collect all model metrics
+    model_metrics = dict()
+
+    # individual dominance: R-squared of each predictor on its own
+    individual_dominance = []
+    for i_pred in range(n_predictor):
+        individual_dominance.append(model_r_sq[(i_pred,)])
+    individual_dominance = np.array(individual_dominance).reshape(1, -1)
+    model_metrics["individual_dominance"] = individual_dominance
+
+    # partial dominance: average gain in R-squared when a predictor joins
+    # every sub-model of a given size
+    partial_dominance = [[] for _ in range(n_predictor - 1)]
+    for i_len in range(n_predictor - 1):
+        i_len_combs = list(combinations(range(n_predictor), i_len + 2))
+        for j_node in range(n_predictor):
+            j_node_sel = [v for v in i_len_combs if j_node in v]
+            reduced_list = [remove_ret(comb, j_node) for comb in j_node_sel]
+            diff_values = [
+                model_r_sq[j_node_sel[i]] - model_r_sq[reduced_list[i]]
+                for i in range(len(reduced_list))]
+            partial_dominance[i_len].append(np.mean(diff_values))
+
+    # save partial dominance
+    partial_dominance = np.array(partial_dominance)
+    model_metrics["partial_dominance"] = partial_dominance
+    # total dominance: average of individual and partial dominance
+    total_dominance = np.mean(
+        np.r_[individual_dominance, partial_dominance], axis=0)
+    # check and save total dominance
+    assert np.allclose(total_dominance.sum(),
+                       model_r_sq[tuple(range(n_predictor))]), \
+        "Sum of total dominance is not equal to full r square!"
+    model_metrics["total_dominance"] = total_dominance
+    # save full R-squared
+    model_metrics["full_r_sq"] = model_r_sq[tuple(range(n_predictor))]
+
+    return model_metrics, model_r_sq
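
For orientation, here is a minimal usage sketch of the new function (not part of the commit; the synthetic dataset from sklearn's make_regression and the problem sizes are illustrative assumptions). It mirrors the assert in the code above: the per-predictor total dominance values should sum to the full-model R-squared.

    # hypothetical illustration, not part of the commit
    import numpy as np
    from sklearn.datasets import make_regression
    from netneurotools.stats import get_dominance_stats

    # assumed sizes: 200 samples, 4 predictors
    X, y = make_regression(n_samples=200, n_features=4, noise=10.0,
                           random_state=0)
    model_metrics, model_r_sq = get_dominance_stats(X, y)

    # one total-dominance value per predictor; their sum equals full_r_sq
    print(model_metrics["total_dominance"])
    print(np.isclose(model_metrics["total_dominance"].sum(),
                     model_metrics["full_r_sq"]))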

requirements.txt

Lines changed: 1 addition & 0 deletions
@@ -5,3 +5,4 @@ nilearn
 numpy>=1.16
 scikit-learn
 scipy>=1.4.0
+tqdm
