Skip to content

Commit 1938c19

Browse files
adam2392bloebp
authored andcommitted
Address comments
Signed-off-by: Adam Li <[email protected]>
1 parent f9a0b07 commit 1938c19

File tree

2 files changed

+53
-13
lines changed

2 files changed

+53
-13
lines changed

pywhy_stats/power_divergence.py

Lines changed: 51 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
import logging
3030
from typing import Optional
31+
from warnings import warn
3132

3233
import numpy as np
3334
from numpy.typing import ArrayLike
@@ -38,7 +39,11 @@
3839

3940

4041
def ind(
41-
X: ArrayLike, Y: ArrayLike, method: str = "cressie-read", num_categories_allowed: int = 10
42+
X: ArrayLike,
43+
Y: ArrayLike,
44+
method: str = "cressie-read",
45+
num_categories_allowed: int = 10,
46+
on_error: str = "raise",
4247
) -> PValueResult:
4348
"""Perform an independence test using power divergence test.
4449
@@ -65,6 +70,11 @@ def ind(
6570
num_categories_allowed : int
6671
The maximum number of categories allowed in the input variables. Default
6772
of 10 is chosen to error out on large number of categories.
73+
on_error : str
74+
What to do when there are not enough samples in the data, where there are 0 samples
75+
in a cell of the contingency table. If 'raise', then
76+
raise an error. If 'warn', then log a warning and skip the test.
77+
If 'ignore', then ignore the warning and skip the test.
6878
6979
Returns
7080
-------
@@ -79,7 +89,12 @@ def ind(
7989
"""
8090
X, Y, _ = _preprocess_inputs(X=X, Y=Y, Z=None)
8191
return _power_divergence(
82-
X=X, Y=Y, Z=None, method=method, num_categories_allowed=num_categories_allowed
92+
X=X,
93+
Y=Y,
94+
Z=None,
95+
method=method,
96+
num_categories_allowed=num_categories_allowed,
97+
on_error=on_error,
8398
)
8499

85100

@@ -89,6 +104,7 @@ def condind(
89104
condition_on: ArrayLike,
90105
method: str = "cressie-read",
91106
num_categories_allowed: int = 10,
107+
on_error: str = "raise",
92108
) -> PValueResult:
93109
"""Perform an independence test using power divergence test.
94110
@@ -117,6 +133,11 @@ def condind(
117133
num_categories_allowed : int
118134
The maximum number of categories allowed in the input variables. Default
119135
of 10 is chosen to error out on large number of categories.
136+
on_error : str
137+
What to do when there are not enough samples in the data, where there are 0 samples
138+
in a cell of the contingency table. If 'raise', then
139+
raise an error. If 'warn', then log a warning and skip the test.
140+
If 'ignore', then ignore the warning and skip the test.
120141
121142
Returns
122143
-------
@@ -127,7 +148,12 @@ def condind(
127148
"""
128149
X, Y, condition_on = _preprocess_inputs(X=X, Y=Y, Z=condition_on)
129150
return _power_divergence(
130-
X=X, Y=Y, Z=condition_on, method=method, num_categories_allowed=num_categories_allowed
151+
X=X,
152+
Y=Y,
153+
Z=condition_on,
154+
method=method,
155+
num_categories_allowed=num_categories_allowed,
156+
on_error=on_error,
131157
)
132158

133159

@@ -208,6 +234,7 @@ def _power_divergence(
208234
method: str = "cressie-read",
209235
num_categories_allowed: int = 10,
210236
correction: bool = True,
237+
on_error: str = "raise",
211238
) -> PValueResult:
212239
"""Compute the Cressie-Read power divergence statistic.
213240
@@ -298,8 +325,8 @@ def _power_divergence(
298325
)
299326
n_samples_req = 10 * dof_check
300327
if n_samples < n_samples_req:
301-
raise RuntimeError(
302-
f"Not enough samples. {n_samples} is too small. Need {n_samples_req}."
328+
warn(
329+
f"Not enough samples. {n_samples} is probably too small. Should have {n_samples_req}."
303330
)
304331

305332
# XXX: currently we just leverage pandas to do the grouping. This is not
@@ -324,12 +351,25 @@ def _power_divergence(
324351
chi += c
325352
dof += d
326353
except ValueError:
327-
# If one of the values is 0 in the 2x2 table.
328-
if isinstance(z_state, str):
329-
logging.info(f"Skipping the test X \u27C2 Y | Z={z_state}. Not enough samples")
330-
else:
331-
z_str = ", ".join([f"{var}={state}" for var, state in zip(Z_columns, z_state)])
332-
logging.info(f"Skipping the test X \u27C2 Y | {z_str}. Not enough samples")
354+
if on_error == "raise":
355+
raise RuntimeError(
356+
"Not enough samples in the data, such that there is a 0 sample contingency table"
357+
)
358+
elif on_error == "warn":
359+
warn(
360+
"Not enough samples in the data, such that there is a 0 sample contingency table"
361+
)
362+
elif on_error == "ignore":
363+
# If one of the values is 0 in the 2x2 table.
364+
if isinstance(z_state, str):
365+
logging.info(
366+
f"Skipping the test X \u27C2 Y | Z={z_state}. Not enough samples"
367+
)
368+
else:
369+
z_str = ", ".join(
370+
[f"{var}={state}" for var, state in zip(Z_columns, z_state)]
371+
)
372+
logging.info(f"Skipping the test X \u27C2 Y | {z_str}. Not enough samples")
333373

334374
if np.isnan(c):
335375
raise RuntimeError(

tests/test_power_divergence.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,7 @@ def test_g_discrete():
213213
dm = np.array([testdata.dis_data]).reshape((2000, 25))
214214
df = pd.DataFrame.from_records(dm)
215215
sets = [[2, 3, 4, 5, 6, 7]]
216-
with pytest.raises(RuntimeError, match="Not enough samples"):
216+
with pytest.warns(UserWarning, match="Not enough samples"):
217217
power_divergence.condind(
218218
X=df[x], Y=df[y], condition_on=df[sets[0]], method="log-likelihood"
219219
)
@@ -247,7 +247,7 @@ def test_g_binary():
247247
dm = np.array([testdata.bin_data]).reshape((500, 50))
248248
df = pd.DataFrame.from_records(dm)
249249
sets = [[2, 3, 4, 5, 6, 7, 8]]
250-
with pytest.raises(RuntimeError, match="Not enough samples"):
250+
with pytest.warns(UserWarning, match="Not enough samples"):
251251
power_divergence.condind(
252252
X=df[x], Y=df[y], condition_on=df[sets[0]], method="log-likelihood"
253253
)

0 commit comments

Comments
 (0)