Skip to content

Commit f9a0b07

Browse files
adam2392bloebp
authored andcommitted
Make sure coverage is fine
Signed-off-by: Adam Li <[email protected]>
1 parent ff75365 commit f9a0b07

File tree

2 files changed

+42
-5
lines changed

2 files changed

+42
-5
lines changed

pywhy_stats/power_divergence.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -155,9 +155,9 @@ def _preprocess_inputs(X: ArrayLike, Y: ArrayLike, Z: Optional[ArrayLike]) -> Ar
155155
Y = np.asarray(Y)
156156

157157
if not all(type(xi) == type(X[0]) for xi in X): # noqa
158-
raise ValueError("All elements of X must be of the same type.")
158+
raise TypeError("All elements of X must be of the same type.")
159159
if not all(type(yi) == type(Y[0]) for yi in Y): # noqa
160-
raise ValueError("All elements of Y must be of the same type.")
160+
raise TypeError("All elements of Y must be of the same type.")
161161

162162
# Check if all elements are integers
163163
if np.issubdtype(type(X[0]), np.str_):
@@ -185,7 +185,7 @@ def _preprocess_inputs(X: ArrayLike, Y: ArrayLike, Z: Optional[ArrayLike]) -> Ar
185185
Z = Z.reshape(-1, 1)
186186
for icol in range(Z.shape[1]):
187187
if not all(type(zi) == type(Z[0, icol]) for zi in Z[:, icol]): # noqa
188-
raise ValueError(f"All elements of Z in column {icol} must be of the same type.")
188+
raise TypeError(f"All elements of Z in column {icol} must be of the same type.")
189189

190190
# XXX: needed when converting to only numpy API
191191
# Check if all elements are integers
@@ -271,7 +271,9 @@ def _power_divergence(
271271
if Z is None:
272272
# Compute the contingency table
273273
observed_xy, _, _ = np.histogram2d(X, Y, bins=(np.unique(X).size, np.unique(Y).size))
274-
chi, p_value, dof, expected = stats.chi2_contingency(observed_xy, correction=correction, lambda_=method)
274+
chi, p_value, dof, expected = stats.chi2_contingency(
275+
observed_xy, correction=correction, lambda_=method
276+
)
275277

276278
# Step 2: If there are conditionals variables, iterate over unique states and do
277279
# the contingency test.
@@ -316,7 +318,9 @@ def _power_divergence(
316318
sub_table_z = (
317319
df.groupby(X_columns + Y_columns).size().unstack(Y_columns, fill_value=1e-7)
318320
)
319-
c, _, d, _ = stats.chi2_contingency(sub_table_z, correction=correction, lambda_=method)
321+
c, _, d, _ = stats.chi2_contingency(
322+
sub_table_z, correction=correction, lambda_=method
323+
)
320324
chi += c
321325
dof += d
322326
except ValueError:

tests/test_power_divergence.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,3 +292,36 @@ def test_g_binary_highdim():
292292
X=df["x"], Y=df["y"], condition_on=df[list(range(5)) + ["x1"]], method="log-likelihood"
293293
)
294294
assert result.pvalue < 0.05
295+
296+
297+
class TestPreprocessInputs:
298+
"""Test error cases in preprocessing categorical data input."""
299+
300+
# Test for a valid case with mixed integer and string input arrays
301+
def test_preprocess_inputs_mixed(self):
302+
X = np.array([1, "red", 3, "green"], dtype="object")
303+
Y = np.array(["blue", 5, "yellow", 7], dtype="object")
304+
with pytest.raises(TypeError):
305+
power_divergence.ind(X, Y)
306+
307+
# Test for invalid case with 2D array as input
308+
def test_preprocess_inputs_2d(self):
309+
X = np.array([[1, 2], [3, 4]])
310+
Y = np.array([[5, 6], [7, 8]])
311+
Z = None
312+
with pytest.raises(ValueError):
313+
power_divergence.condind(X, Y, Z)
314+
315+
# Test for invalid case with unsupported data type in X array
316+
def test_preprocess_inputs_invalid_X_dtype(self):
317+
X = np.array([1, "red", 3, 4.5], dtype="object")
318+
Y = np.array(["blue", "green", "yellow", "orange"])
319+
with pytest.raises(TypeError):
320+
power_divergence.ind(X, Y)
321+
322+
# Test for invalid case with unsupported data type in Y array
323+
def test_preprocess_inputs_invalid_Y_dtype(self):
324+
X = np.array([1, 2, 3, 4])
325+
Y = np.array(["blue", 5, "yellow", 7.5], dtype="object")
326+
with pytest.raises(TypeError):
327+
power_divergence.ind(X, Y)

0 commit comments

Comments
 (0)