
Commit ea3903d

[MNT] Fix styles
1 parent a8f9dcc commit ea3903d

1 file changed: +17 -22 lines

netneurotools/stats.py

Lines changed: 17 additions & 22 deletions
@@ -7,9 +7,11 @@
 
 import numpy as np
 from tqdm import tqdm
+from itertools import combinations
 from scipy import optimize, spatial, special, stats as sstats
 from scipy.stats.stats import _chk2_asarray
 from sklearn.utils.validation import check_random_state
+from sklearn.linear_model import LinearRegression
 
 from . import utils
 
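
For reference, the `combinations` import moved to module level here is what `get_dominance_stats` uses below to enumerate predictor subsets. A minimal sketch of what it yields for three predictors, mirroring the `predictor_combs` construction further down in the diff:

from itertools import combinations

# every predictor-index subset of a 3-predictor model, grouped by size
predictor_combs = [list(combinations(range(3), i)) for i in range(1, 4)]
for group in predictor_combs:
    print(group)
# [(0,), (1,), (2,)]
# [(0, 1), (0, 2), (1, 2)]
# [(0, 1, 2)]
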

@@ -845,8 +847,11 @@ def get_dominance_stats(X, y, use_adjusted_r_sq=True, verbose=False):
 
     """
 
-    from itertools import combinations
-    from sklearn.linear_model import LinearRegression
+    # this helps to remove one element from a tuple
+    def remove_ret(tpl, elem):
+        lst = list(tpl)
+        lst.remove(elem)
+        return tuple(lst)
 
     # sklearn linear regression wrapper
     def get_reg_r_sq(X, y):
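
The nested `remove_ret` helper introduced above simply returns a copy of a tuple with one element dropped. A quick standalone illustration of its behaviour:

def remove_ret(tpl, elem):
    # return a copy of `tpl` with the first occurrence of `elem` removed
    lst = list(tpl)
    lst.remove(elem)
    return tuple(lst)

print(remove_ret((0, 1, 2), 1))   # (0, 2)
print(remove_ret((0, 1, 2), 0))   # (1, 2)
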
@@ -863,24 +868,21 @@ def get_reg_r_sq(X, y):
         else:
             return r_squared
 
-    def complete_model_rsquare(X, y):
-        return get_reg_r_sq(X, y)
-
     # generate all predictor combinations in list (num of predictors) of lists
     n_predictor = X.shape[-1]
     # n_comb_len_group = n_predictor - 1
     predictor_combs = [list(combinations(range(n_predictor), i))
                        for i in range(1, n_predictor + 1)]
     if verbose:
         print(f"[Dominance analysis] Generated \
-            {len([_ for i in predictor_combs for _ in i])} combinations")
+            {len([v for i in predictor_combs for v in i])} combinations")
 
     # get all r_sq's
-    model_r_sq = dict([])
+    model_r_sq = dict()
     for len_group in tqdm(predictor_combs, desc='num-of-predictor loop',
-                          disable=~verbose):
+                          disable=not verbose):
         for idx_tuple in tqdm(len_group, desc='insider loop',
-                              disable=~verbose):
+                              disable=not verbose):
             r_sq = get_reg_r_sq(X[:, idx_tuple], y)
             model_r_sq[idx_tuple] = r_sq
     if verbose:
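
A note on the `disable=not verbose` change above: tqdm's `disable` flag expects a boolean, and applying bitwise `~` to a Python bool operates on its integer value, so the result is truthy for both `True` and `False`. A small check of the semantics (plain Python, independent of tqdm):

verbose = True
print(~verbose, bool(~verbose))   # -2 True  -> bar disabled even though verbose
print(not verbose)                # False    -> bar shown, as intended

verbose = False
print(~verbose, bool(~verbose))   # -1 True  -> bar disabled
print(not verbose)                # True     -> bar disabled, as intended
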
@@ -896,24 +898,16 @@ def complete_model_rsquare(X, y):
     individual_dominance = np.array(individual_dominance).reshape(1, -1)
     model_metrics["individual_dominance"] = individual_dominance
 
-    def remove_ret(tpl, elem):
-        lst = list(tpl)
-        lst.remove(elem)
-        return tuple(lst)
-
     # partial dominance
-    partial_dominance = [[] for _ in range(n_predictor - 1)]
+    partial_dominance = [[]] * (n_predictor - 1)
     for i_len in range(n_predictor - 1):
         i_len_combs = list(combinations(range(n_predictor), i_len + 2))
-        # print(i_len_combs)
         for j_node in range(n_predictor):
-            j_node_sel = [_ for _ in i_len_combs if j_node in _]
+            j_node_sel = [v for v in i_len_combs if j_node in v]
             reduced_list = [remove_ret(comb, j_node) for comb in j_node_sel]
-            # print(j_node, j_node_sel, reduced_list)
             diff_values = [
-                model_r_sq[j_node_sel[_]] - model_r_sq[reduced_list[_]]
-                for _ in range(len(reduced_list))]
-            # print(diff_values)
+                model_r_sq[j_node_sel[i]] - model_r_sq[reduced_list[i]]
+                for i in range(len(reduced_list))]
             partial_dominance[i_len].append(np.mean(diff_values))
 
     # save partial dominance
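
To make the partial-dominance loop above concrete: for each subset size, it takes every predictor combination containing predictor `j`, subtracts the R² of the same combination with `j` removed, and averages those gains. A minimal sketch of that bookkeeping with a hypothetical 3-predictor `model_r_sq` table (the R² values below are invented for illustration; independent sub-lists are used so the appends stay separate):

from itertools import combinations
import numpy as np

def remove_ret(tpl, elem):
    # drop one element from a tuple, returning a new tuple
    lst = list(tpl)
    lst.remove(elem)
    return tuple(lst)

# hypothetical R^2 for every predictor subset of a 3-predictor model
model_r_sq = {
    (0,): 0.20, (1,): 0.30, (2,): 0.10,
    (0, 1): 0.45, (0, 2): 0.28, (1, 2): 0.38,
    (0, 1, 2): 0.50,
}

n_predictor = 3
partial_dominance = [[] for _ in range(n_predictor - 1)]  # one list per subset size
for i_len in range(n_predictor - 1):
    i_len_combs = list(combinations(range(n_predictor), i_len + 2))
    for j_node in range(n_predictor):
        j_node_sel = [v for v in i_len_combs if j_node in v]
        reduced_list = [remove_ret(comb, j_node) for comb in j_node_sel]
        diff_values = [model_r_sq[j_node_sel[i]] - model_r_sq[reduced_list[i]]
                       for i in range(len(reduced_list))]
        partial_dominance[i_len].append(np.mean(diff_values))

# row 0: average R^2 gained by adding each predictor to a 1-predictor model
# row 1: average R^2 gained by adding each predictor to a 2-predictor model
print(np.round(partial_dominance, 3))
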
@@ -924,7 +918,8 @@ def remove_ret(tpl, elem):
         np.r_[individual_dominance, partial_dominance], axis=0)
     # test and save total dominance
     assert np.allclose(total_dominance.sum(),
-                       model_r_sq[tuple(range(n_predictor))])
+                       model_r_sq[tuple(range(n_predictor))]), \
+        "Sum of total dominance is not equal to full r square!"
     model_metrics["total_dominance"] = total_dominance
     # save full r^2
     model_metrics["full_r_sq"] = model_r_sq[tuple(range(n_predictor))]
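
The strengthened assert checks that summed dominance recovers the full-model R². A usage sketch of the function on synthetic data (the signature and dictionary keys are taken from the diff; the exact return shape is not shown here, so the unpacking is written defensively):

import numpy as np
from netneurotools import stats

rng = np.random.default_rng(1234)
X = rng.standard_normal((100, 4))   # 100 samples, 4 predictors
y = X @ np.array([0.5, 0.2, -0.3, 0.1]) + 0.1 * rng.standard_normal(100)

out = stats.get_dominance_stats(X, y, use_adjusted_r_sq=True, verbose=False)
# grab the metrics dict whether it is returned alone or alongside the
# per-subset R^2 table
model_metrics = out[0] if isinstance(out, tuple) else out

# total dominance across predictors should sum (approximately) to full R^2
print(model_metrics["total_dominance"].sum())
print(model_metrics["full_r_sq"])
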
