Skip to content

Commit 471c3e7

Browse files
authored
[ENH] KCD and Bregman conditional 2-sample tests (#21)
Closes: #15 Follow-up to #19 Changes proposed in this pull request: - Adds `kcd` and `bremgan` test along w/ unit-tests and documentation update - As a result of #19, the code is entirely self-contained and leverages the kernel functions that are shared w/ the kci test. ## Before submitting <!-- Please complete this checklist BEFORE submitting your PR to speed along the review process. --> - [ ] I've read and followed all steps in the [Making a pull request](https://github.com/py-why/pywhy-stats/blob/main/CONTRIBUTING.md#making-a-pull-request) section of the `CONTRIBUTING` docs. - [ ] I've updated or added any relevant docstrings following the syntax described in the [Writing docstrings](https://github.com/py-why/pywhy-stats/blob/main/CONTRIBUTING.md#writing-docstrings) section of the `CONTRIBUTING` docs. - [ ] If this PR fixes a bug, I've added a test that will fail without my fix. - [ ] If this PR adds a new feature, I've added tests that sufficiently cover my new functionality. - [ ] I have added a changelog entry succintly describing the change in this PR in the [whats_new](https://github.com/py-why/pywhy-stats/blob/main/docs/whats_new/) relevant version document. ## After submitting <!-- Please complete this checklist AFTER submitting your PR to speed along the review process. --> - [ ] All GitHub Actions jobs for my pull request have passed. --------- Signed-off-by: Adam Li <[email protected]>
1 parent 256d8c9 commit 471c3e7

File tree

11 files changed

+912
-2
lines changed

11 files changed

+912
-2
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Minimally, pywhy-stats requires:
2525
* Python (>=3.8)
2626
* numpy
2727
* scipy
28+
* scikit-learn
2829

2930
## User Installation
3031

doc/api.rst

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,3 +64,17 @@ of many data analysis procedures.
6464
fisherz
6565
kci
6666
power_divergence
67+
68+
(Conditional) K-Sample Testing
69+
==============================
70+
71+
Testing for invariances among conditional distributions is a core part
72+
of many data analysis procedures. Currently, we only support conditional
73+
2-sample testing among two distributions.
74+
75+
.. currentmodule:: pywhy_stats.conditional_ksample
76+
.. autosummary::
77+
:toctree: generated/
78+
79+
bregman
80+
kcd

doc/conditional_independence.rst

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,35 @@ indices of the distribution, one can convert the CD test:
177177
:math:`P_{i=j}(y|x) =? P_{i=k}(y|x)` into the CI test :math:`P(y|x,i) = P(y|x)`, which can
178178
be tested with the Chi-square CI tests.
179179

180+
:mod:`pywhy_stats.conditional_ksample.kcd` Kernel-Approaches
181+
------------------------------------------------------------
182+
Kernel-based tests are attractive since they are semi-parametric and use kernel-based ideas
183+
that have been shown to be robust in the machine-learning field. The Kernel CD test is a test
184+
that computes a test statistic from kernels of the data and uses a weighted permutation testing
185+
based on the estimated propensity scores to generate samples from the null distribution
186+
:footcite:`Park2021conditional`, which are then used to estimate a pvalue.
187+
188+
.. currentmodule:: pywhy_stats.conditional_ksample
189+
.. autosummary::
190+
:toctree: generated/
191+
192+
kcd
193+
194+
195+
:mod:`pywhy_stats.conditional_ksample.bregman` Bregman-Divergences
196+
------------------------------------------------------------------
197+
The Bregman CD test is a divergence-based test
198+
that computes a test statistic from estimated Von-Neumann divergences of the data and uses a
199+
weighted permutation testing based on the estimated propensity scores to generate samples from the null distribution
200+
:footcite:`Yu2020Bregman`, which are then used to estimate a pvalue.
201+
202+
203+
.. currentmodule:: pywhy_stats.conditional_ksample
204+
.. autosummary::
205+
:toctree: generated/
206+
207+
bregman
208+
180209
==========
181210
References
182211
==========

doc/whats_new/v0.1.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ Changelog
2929
- |Feature| Implement partial correlation test :func:`pywhy_stats.independence.fisherz`, by `Adam Li`_ (:pr:`7`)
3030
- |Feature| Add (un)conditional kernel independence test by `Patrick Blöbaum`_, co-authored by `Adam Li`_ (:pr:`14`)
3131
- |Feature| Add categorical independence tests by `Adam Li`_, (:pr:`18`)
32-
32+
- |Feature| Add conditional kernel and Bregman discrepancy tests, `pywhy_stats.kcd` and `pywhy_stats.bregman` by `Adam Li`_ (:pr:`21`)
3333

3434
Code and Documentation Contributors
3535
-----------------------------------
@@ -38,4 +38,4 @@ Thanks to everyone who has contributed to the maintenance and improvement of
3838
the project since version inception, including:
3939

4040
* `Adam Li`_
41-
41+
* `Patrick Blöbaum`_

pywhy_stats/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from ._version import __version__ # noqa: F401
22
from .api import Methods, independence_test
3+
from .conditional_ksample import bregman, kcd
34
from .independence import fisherz, kci, power_divergence
45
from .pvalue_result import PValueResult
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from . import bregman, kcd
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
from typing import Callable, Optional
2+
3+
import numpy as np
4+
from joblib import Parallel, delayed
5+
from numpy.typing import ArrayLike
6+
from sklearn.base import BaseEstimator
7+
from sklearn.linear_model import LogisticRegression
8+
9+
from pywhy_stats.kernel_utils import _default_regularization
10+
11+
12+
def _preprocess_propensity_data(
13+
group_ind: ArrayLike,
14+
propensity_model: Optional[BaseEstimator],
15+
propensity_weights: Optional[ArrayLike],
16+
):
17+
if group_ind.ndim != 1:
18+
raise RuntimeError("group_ind must be a 1d array.")
19+
if len(np.unique(group_ind)) != 2:
20+
raise RuntimeError(
21+
f"There should only be two groups. Found {len(np.unique(group_ind))} groups."
22+
)
23+
if propensity_model is not None and propensity_weights is not None:
24+
raise ValueError(
25+
"Both propensity model and propensity estimates are specified. Only one is allowed."
26+
)
27+
if propensity_weights is not None:
28+
if propensity_weights.shape[0] != len(group_ind):
29+
raise ValueError(
30+
f"There are {propensity_weights.shape[0]} pre-defined estimates, while "
31+
f"there are {len(group_ind)} samples."
32+
)
33+
if propensity_weights.shape[1] != len(np.unique(group_ind.squeeze())):
34+
raise ValueError(
35+
f"There are {propensity_weights.shape[1]} group pre-defined estimates, while "
36+
f"there are {len(np.unique(group_ind))} unique groups."
37+
)
38+
39+
40+
def _compute_propensity_scores(
41+
group_ind: ArrayLike,
42+
propensity_model: Optional[BaseEstimator] = None,
43+
propensity_weights: Optional[ArrayLike] = None,
44+
n_jobs: Optional[int] = None,
45+
random_state: Optional[int] = None,
46+
**kwargs,
47+
):
48+
if propensity_model is None:
49+
K: ArrayLike = kwargs.get("K")
50+
51+
# compute a default penalty term if using a kernel matrix
52+
# C is the inverse of the regularization parameter
53+
if K.shape[0] == K.shape[1]:
54+
# default regularization is 1 / (2 * K)
55+
propensity_penalty_ = _default_regularization(K)
56+
C = 1 / (2 * propensity_penalty_)
57+
else:
58+
# defaults to no regularization
59+
propensity_penalty_ = 0.0
60+
C = 1.0
61+
62+
# default model is logistic regression
63+
propensity_model_ = LogisticRegression(
64+
penalty="l2",
65+
n_jobs=n_jobs,
66+
warm_start=True,
67+
solver="lbfgs",
68+
random_state=random_state,
69+
C=C,
70+
)
71+
else:
72+
propensity_model_ = propensity_model
73+
74+
# either use pre-defined propensity weights, or estimate them
75+
if propensity_weights is None:
76+
K = kwargs.get("K")
77+
# fit and then obtain the probabilities of treatment
78+
# for each sample (i.e. the propensity scores)
79+
propensity_weights = propensity_model_.fit(K, group_ind.ravel()).predict_proba(K)[:, 1]
80+
else:
81+
propensity_weights = propensity_weights[:, 1]
82+
return propensity_weights
83+
84+
85+
def compute_null(
86+
func: Callable,
87+
e_hat: ArrayLike,
88+
X: ArrayLike,
89+
Y: ArrayLike,
90+
null_reps: int = 1000,
91+
n_jobs=None,
92+
seed=None,
93+
**kwargs,
94+
) -> ArrayLike:
95+
"""Estimate null distribution using propensity weights.
96+
97+
Parameters
98+
----------
99+
func : Callable
100+
The function to compute the test statistic.
101+
e_hat : Array-like of shape (n_samples,)
102+
The predicted propensity score for ``group_ind == 1``.
103+
X : Array-Like of shape (n_samples, n_features_x)
104+
The X (covariates) array.
105+
Y : Array-Like of shape (n_samples, n_features_y)
106+
The Y (outcomes) array.
107+
null_reps : int, optional
108+
Number of times to sample null, by default 1000.
109+
n_jobs : int, optional
110+
Number of jobs to run in parallel, by default None.
111+
seed : int, optional
112+
Random generator, or random seed, by default None.
113+
114+
Returns
115+
-------
116+
null_dist : Array-like of shape (n_samples,)
117+
The null distribution of test statistics.
118+
"""
119+
rng = np.random.default_rng(seed)
120+
n_samps = X.shape[0]
121+
122+
# compute the test statistic on the conditionally permuted
123+
# dataset, where each group label is resampled for each sample
124+
# according to its propensity score
125+
null_dist = Parallel(n_jobs=n_jobs)(
126+
[
127+
delayed(func)(X, Y, group_ind=rng.binomial(1, e_hat, size=n_samps), **kwargs)
128+
for _ in range(null_reps)
129+
]
130+
)
131+
return np.asarray(null_dist)

0 commit comments

Comments
 (0)