Skip to content

Commit b8d086f

Browse files
Merge pull request #5 from bystrogenomics/feature/standardize
Standardized some methods
2 parents acf0671 + 6618a44 commit b8d086f

File tree

3 files changed

+59
-30
lines changed

3 files changed

+59
-30
lines changed

covtest/methods/hypothesis_identity.py

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
from . import _srivastava_2005 as s2005
3838
from . import _tylers as tyler
39+
from .utils import validate_data_matrix
3940

4041

4142
def _ahmad_2015_stat(x: np.ndarray) -> float:
@@ -60,20 +61,21 @@ def _ahmad_2015_stat(x: np.ndarray) -> float:
6061
return nrow * (c3 / ncol - 2.0 * c1 / ncol + 1.0)
6162

6263

63-
def ahmad2015_identity(x, Sigma="identity"):
64+
def ahmad2015_identity(X, Sigma="identity"):
6465
"""
6566
Ahmad & von Rosen (2015) test of covariance matrix structure,
66-
when a data matrix x (n x p) is supplied.
67+
when a data matrix X (n x p) is supplied.
6768
"""
68-
n, p = x.shape
69+
X = validate_data_matrix(X)
70+
n, p = X.shape
6971

7072
if Sigma == "identity":
71-
x_ = x
73+
X_ = X
7274
else:
7375
u_s, d_s, _ = svd(Sigma)
74-
x_ = x @ solve(u_s @ np.diag(np.sqrt(d_s)), np.eye(p))
76+
X_ = X @ solve(u_s @ np.diag(np.sqrt(d_s)), np.eye(p))
7577

76-
statistic = _ahmad_2015_stat(x_)
78+
statistic = _ahmad_2015_stat(X_)
7779
parameter = {"Mean": 0, "Variance": 4 * (2 / (p / n + 1))}
7880
pval = 2 * (
7981
1
@@ -128,12 +130,12 @@ def _ledoit_wolf_stat(data):
128130

129131

130132
# Checked
131-
def ledoit_wolf_identity(data):
133+
def ledoit_wolf_identity(X):
132134
"""Perform the Ledoit–Wolf test for identity covariance.
133135
134136
Parameters
135137
----------
136-
data : array-like of shape (n_samples, n_features)
138+
X : array-like of shape (n_samples, n_features)
137139
The data matrix, where rows correspond to samples and
138140
columns to variables.
139141
@@ -147,8 +149,9 @@ def ledoit_wolf_identity(data):
147149
- ``'p_value'`` : float
148150
The p-value from the chi-square distribution.
149151
"""
150-
n, p = data.shape
151-
W = _ledoit_wolf_stat(data)
152+
X = validate_data_matrix(X)
153+
n, p = X.shape
154+
W = _ledoit_wolf_stat(X)
152155
degree_of_freedom = p * (p + 1) / 2
153156
stat = n * p / 2 * W
154157
p_value = 1 - stats.chi2.cdf(stat, degree_of_freedom)
@@ -192,12 +195,12 @@ def _nagao_stat(data):
192195

193196

194197
# Checked
195-
def nagao_identity(data):
198+
def nagao_identity(X):
196199
"""Perform Nagao’s test for identity covariance.
197200
198201
Parameters
199202
----------
200-
data : array-like of shape (n_samples, n_features)
203+
X : array-like of shape (n_samples, n_features)
201204
The data matrix, where rows correspond to samples and columns to variables.
202205
203206
Returns
@@ -223,8 +226,9 @@ def nagao_identity(data):
223226
The null hypothesis is :math:`\\Sigma = I_p`, where :math:`\\Sigma` is the covariance
224227
matrix and :math:`I_p` is the identity matrix.
225228
"""
226-
n, p = data.shape
227-
V = _nagao_stat(data)
229+
X = validate_data_matrix(X)
230+
n, p = X.shape
231+
V = _nagao_stat(X)
228232
degree_of_freedom = p * (p + 1) / 2
229233
stat = n * p / 2 * V
230234
p_value = 1 - stats.chi2.cdf(stat, degree_of_freedom)
@@ -234,6 +238,7 @@ def nagao_identity(data):
234238

235239
# Checked
236240
def srivastava_2005_identity(X):
241+
X = validate_data_matrix(X)
237242
n = X.shape[0]
238243
S = np.cov(X.T)
239244
T_1 = s2005.T_1_stat(S, n)
@@ -250,6 +255,7 @@ def tyler_identity(X, unknown_mean=False, method="tr"):
250255
One-sample test H0: Sigma = I_p.
251256
If unknown_mean=True, uses robust location-adjusted version.
252257
"""
258+
X = validate_data_matrix(X)
253259
n, p = X.shape
254260
if unknown_mean:
255261
mu_hat = tyler.robust_location(X)
@@ -309,10 +315,11 @@ def _fisher_2012_stat_(n, p, S_):
309315
return (n / np.sqrt(8 * (c**2 + 12 * c + 8))) * (ahat4 - 2 * ahat2 + 1)
310316

311317

312-
def fisher_single_sample(x, Sigma="identity"):
313-
p = x.shape[1]
314-
n = x.shape[0]
315-
S = np.cov(x, rowvar=False)
318+
def fisher_single_sample(X, Sigma="identity"):
319+
X = validate_data_matrix(X)
320+
p = X.shape[1]
321+
n = X.shape[0]
322+
S = np.cov(X, rowvar=False)
316323

317324
if Sigma == "identity":
318325
S_ = S
@@ -345,10 +352,11 @@ def _srivastava2011_(n, p, S_):
345352
return n * (term1 - term2 + 1) / 2
346353

347354

348-
def srivastava2011_single_sample(x, Sigma="identity"):
349-
p = x.shape[1]
350-
n = x.shape[0]
351-
S = np.cov(x, rowvar=False)
355+
def srivastava2011_single_sample(X, Sigma="identity"):
356+
X = validate_data_matrix(X)
357+
p = X.shape[1]
358+
n = X.shape[0]
359+
S = np.cov(X, rowvar=False)
352360

353361
if Sigma == "identity":
354362
S_ = S

covtest/methods/hypothesis_spherical.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from . import _hallin2006 as hallin2006
55
from . import _srivastava_2005 as s2005
6+
from .utils import validate_data_matrix
67

78

89
def bartlett_sphericity_test(X):
@@ -23,7 +24,7 @@ def bartlett_sphericity_test(X):
2324
p_value : float
2425
p-value for the test.
2526
"""
26-
27+
X = validate_data_matrix(X)
2728
n, p = X.shape
2829

2930
# Compute correlation matrix
@@ -82,12 +83,12 @@ def _john_stat(data):
8283

8384

8485
# Checked
85-
def john_sphericity(data):
86+
def john_sphericity(X):
8687
"""Perform John's sphericity hypothesis test.
8788
8889
Parameters
8990
----------
90-
data : array-like of shape (n_samples, n_features)
91+
X : array-like of shape (n_samples, n_features)
9192
The data matrix, where rows correspond to samples and columns
9293
to variables.
9394
@@ -123,8 +124,9 @@ def john_sphericity(data):
123124
Biometrika, 58(1), 123–127.
124125
https://doi.org/10.1093/biomet/58.1.123
125126
"""
126-
n, p = data.shape
127-
U = _john_stat(data)
127+
X = validate_data_matrix(X)
128+
n, p = X.shape
129+
U = _john_stat(X)
128130
degree_of_freedom = p * (p + 1) / 2 - 1
129131
stat = U * n * p / 2
130132
p_value = 1 - stats.chi2.cdf(stat, degree_of_freedom)
@@ -134,6 +136,7 @@ def john_sphericity(data):
134136

135137
# Checked
136138
def srivastava_2005_sphericity(X):
139+
X = validate_data_matrix(X)
137140
n = X.shape[0]
138141
S = np.cov(X.T)
139142
T_1 = s2005.T_1_stat(S, n)
@@ -172,7 +175,7 @@ def sk_test(X):
172175
Returns:
173176
dict(statistic=Q, z=z, p_value=p)
174177
"""
175-
X = np.asarray(X)
178+
X = validate_data_matrix(X)
176179
n, p = X.shape
177180
if n < 4:
178181
raise ValueError("Need n >= 4 for the leave-one-out estimator.")
@@ -263,7 +266,7 @@ def muirhead_sphericity_lrt(
263266
raise ValueError("If S is provided, also provide n (sample size).")
264267

265268
if X is not None:
266-
X = np.asarray(X, dtype=float)
269+
X = validate_data_matrix(X)
267270
if X.ndim != 2:
268271
raise ValueError("X must be 2D.")
269272
n, p = X.shape
@@ -354,7 +357,7 @@ def czz_sphericity_test(X, center=False):
354357
sum_{i,j,k all distinct} <X_i,X_j><X_i,X_k> = sum_R2 - sumsq_off
355358
sum_{i,j,k,l all distinct} <X_i,X_j><X_k,X_l> = s_off^2 - 4*sum_R2 + 2*sumsq_off
356359
"""
357-
X = np.asarray(X, dtype=float)
360+
X = validate_data_matrix(X)
358361
if X.ndim != 2:
359362
raise ValueError("X must be a 2D array (n, p).")
360363
n, p = X.shape
@@ -404,6 +407,7 @@ def hallin_rank_sphericity_test(X, method="wilcoxon"):
404407
"""
405408
Van der Waerden (normal-score) rank-based test for sphericity.
406409
"""
410+
X = validate_data_matrix(X)
407411
if method == "wilcoxon":
408412
n, k = X.shape
409413
U, d = hallin2006._center_and_scale(X)

covtest/methods/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from sklearn.utils import check_array
2+
3+
def validate_data_matrix(X):
    """
    Validate that X is a 2D array of numeric values.

    Parameters
    ----------
    X : array-like
        Input data; anything `sklearn.utils.check_array` accepts
        (lists of lists, ndarray, DataFrame, ...).

    Returns
    -------
    X_validated : ndarray
        Validated 2D numeric array with all entries finite.

    Raises
    ------
    ValueError
        If X is not 2D, contains NaN/inf, or cannot be coerced to a
        numeric dtype.
    """
    try:
        # scikit-learn >= 1.6: keyword renamed to `ensure_all_finite`
        # (`force_all_finite` is deprecated and removed in 1.8).
        return check_array(
            X,
            ensure_2d=True,
            allow_nd=False,
            ensure_all_finite=True,
            dtype="numeric",
        )
    except TypeError:
        # Fallback for scikit-learn < 1.6, which only knows the old
        # `force_all_finite` keyword.
        return check_array(
            X,
            ensure_2d=True,
            allow_nd=False,
            force_all_finite=True,
            dtype="numeric",
        )

0 commit comments

Comments
 (0)