Skip to content

Commit f3458d5

Browse files
Copilot and crvernon authored
Address remaining code review comments from PR #159 (#165)
* Initial plan * Address remaining PR review comments Co-authored-by: crvernon <3947069+crvernon@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: crvernon <3947069+crvernon@users.noreply.github.com> Co-authored-by: Chris Vernon <chrisrvernon@gmail.com>
1 parent 4071d0e commit f3458d5

File tree

3 files changed

+117
-141
lines changed

3 files changed

+117
-141
lines changed

msdbook/tests/test_utils.py

Lines changed: 96 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -3,89 +3,116 @@
33
import pandas as pd
44
import matplotlib.pyplot as plt
55
from msdbook.utils import fit_logit, plot_contour_map
6-
from statsmodels.base.wrapper import ResultsWrapper
76
import warnings
87
from statsmodels.tools.sm_exceptions import HessianInversionWarning
9-
warnings.simplefilter("ignore", HessianInversionWarning)
108

9+
warnings.simplefilter("ignore", HessianInversionWarning)
1110

1211

1312
@pytest.fixture
1413
def sample_data():
1514
"""Fixture to provide sample data for testing."""
1615
np.random.seed(0) # For reproducibility
17-
16+
1817
# Number of samples
1918
n = 100
2019

2120
# Generate some random data
22-
df = pd.DataFrame({
23-
'Success': np.random.randint(0, 2, size=n), # Binary outcome variable (0 or 1)
24-
'Predictor1': np.random.randn(n), # Random values for Predictor1
25-
'Predictor2': np.random.randn(n), # Random values for Predictor2
26-
'Interaction': np.random.randn(n) # Random values for Interaction term
27-
})
21+
df = pd.DataFrame(
22+
{
23+
"Success": np.random.randint(0, 2, size=n), # Binary outcome variable (0 or 1)
24+
"Predictor1": np.random.randn(n), # Random values for Predictor1
25+
"Predictor2": np.random.randn(n), # Random values for Predictor2
26+
"Interaction": np.random.randn(n), # Random values for Interaction term
27+
}
28+
)
2829

2930
return df
30-
predictor1 = 'Predictor1'
31-
predictor2 = 'Predictor2'
32-
interaction = 'Interaction'
33-
intercept = 'Intercept'
34-
success = 'Success'
31+
32+
33+
predictor1 = "Predictor1"
34+
predictor2 = "Predictor2"
35+
interaction = "Interaction"
36+
intercept = "Intercept"
37+
success = "Success"
38+
3539

3640
# @pytest.mark.parametrize("predictors, expected_params, min_coeff, max_coeff", [
3741
# (['Predictor1', 'Predictor2'], np.array([0.34060709, -0.26968773, 0.31551482, 0.45824332]), 1e-5, 10), # Adjusted expected params
3842
# ])
3943
@pytest.mark.filterwarnings("ignore:Inverting hessian failed, no bse or cov_params available")
40-
@pytest.mark.parametrize("sample_data, df_resid, df_model, llf", [
41-
(pd.DataFrame({
42-
predictor1: [1.0, 2.0, 3.0],
43-
predictor2: [3.0, 4.0, 5.0],
44-
interaction: [2.0, 4.0, 6.0],
45-
intercept: [1.0, 1.0, 1.0],
46-
success: [1.0, 1.0, 0.0],
47-
}), 0.0, 2.0, -6.691275315650184e-06),
48-
49-
(pd.DataFrame({
50-
predictor1: [5.0, 6.0, 7.0],
51-
predictor2: [7.0, 8.0, 9.0],
52-
interaction: [3.0, 6.0, 9.0],
53-
intercept: [1.0, 1.0, 1.0],
54-
success: [1.0, 0.0, 1.0],
55-
}), 0.0, 2.0, -2.4002923915238235e-06),
56-
57-
(pd.DataFrame({
58-
predictor1: [0.5, 1.5, 2.5],
59-
predictor2: [1.0, 2.0, 3.0],
60-
interaction: [0.2, 0.4, 0.6],
61-
intercept: [1.0, 1.0, 1.0],
62-
success: [0.0, 1.0, 1.0],
63-
}), 0.0, 2.0, -1.7925479970021486e-05)
64-
])
44+
@pytest.mark.parametrize(
45+
"sample_data, df_resid, df_model, llf",
46+
[
47+
(
48+
pd.DataFrame(
49+
{
50+
predictor1: [1.0, 2.0, 3.0],
51+
predictor2: [3.0, 4.0, 5.0],
52+
interaction: [2.0, 4.0, 6.0],
53+
intercept: [1.0, 1.0, 1.0],
54+
success: [1.0, 1.0, 0.0],
55+
}
56+
),
57+
0.0,
58+
2.0,
59+
-6.691275315650184e-06,
60+
),
61+
(
62+
pd.DataFrame(
63+
{
64+
predictor1: [5.0, 6.0, 7.0],
65+
predictor2: [7.0, 8.0, 9.0],
66+
interaction: [3.0, 6.0, 9.0],
67+
intercept: [1.0, 1.0, 1.0],
68+
success: [1.0, 0.0, 1.0],
69+
}
70+
),
71+
0.0,
72+
2.0,
73+
-2.4002923915238235e-06,
74+
),
75+
(
76+
pd.DataFrame(
77+
{
78+
predictor1: [0.5, 1.5, 2.5],
79+
predictor2: [1.0, 2.0, 3.0],
80+
interaction: [0.2, 0.4, 0.6],
81+
intercept: [1.0, 1.0, 1.0],
82+
success: [0.0, 1.0, 1.0],
83+
}
84+
),
85+
0.0,
86+
2.0,
87+
-1.7925479970021486e-05,
88+
),
89+
],
90+
)
6591
def test_fit_logit(sample_data, df_resid, df_model, llf):
6692
predictors = [predictor1, predictor2]
6793
result = fit_logit(sample_data, predictors)
6894
assert result.df_resid == df_resid
6995
assert result.df_model == df_model
7096
assert result.llf == llf
7197

98+
7299
def test_plot_contour_map(sample_data):
73100
"""Test the plot_contour_map function."""
74101
fig, ax = plt.subplots()
75102

76103
# Fit a logit model for the purpose of plotting
77-
result = fit_logit(sample_data, ['Predictor1', 'Predictor2'])
104+
result = fit_logit(sample_data, ["Predictor1", "Predictor2"])
78105

79106
# Dynamically generate grid and levels based on input data to reflect the data range
80-
xgrid_min, xgrid_max = sample_data['Predictor1'].min(), sample_data['Predictor1'].max()
81-
ygrid_min, ygrid_max = sample_data['Predictor2'].min(), sample_data['Predictor2'].max()
107+
xgrid_min, xgrid_max = sample_data["Predictor1"].min(), sample_data["Predictor1"].max()
108+
ygrid_min, ygrid_max = sample_data["Predictor2"].min(), sample_data["Predictor2"].max()
82109
xgrid = np.linspace(xgrid_min - 1, xgrid_max + 1, 50)
83110
ygrid = np.linspace(ygrid_min - 1, ygrid_max + 1, 50)
84111
levels = np.linspace(0, 1, 10)
85-
86-
contour_cmap = 'viridis'
87-
dot_cmap = 'coolwarm'
88-
112+
113+
contour_cmap = "viridis"
114+
dot_cmap = "coolwarm"
115+
89116
# Call the plot function
90117
contourset = plot_contour_map(
91118
ax,
@@ -96,114 +123,63 @@ def test_plot_contour_map(sample_data):
96123
levels,
97124
xgrid,
98125
ygrid,
99-
'Predictor1',
100-
'Predictor2',
101-
base=0,
126+
"Predictor1",
127+
"Predictor2",
102128
)
103129

104130
# Verify the plot and axis limits/labels are correct
105131
assert contourset is not None
106132
assert ax.get_xlim() == (xgrid.min(), xgrid.max())
107133
assert ax.get_ylim() == (ygrid.min(), ygrid.max())
108-
assert ax.get_xlabel() == 'Predictor1'
109-
assert ax.get_ylabel() == 'Predictor2'
134+
assert ax.get_xlabel() == "Predictor1"
135+
assert ax.get_ylabel() == "Predictor2"
110136

111137
# Verify that scatter plot is present by checking the number of points
112-
assert len(ax.collections) > 0
138+
assert len(ax.collections) > 0
113139
plt.close(fig)
114140

115141

116142
def test_empty_data():
117143
"""Test with empty data to ensure no errors."""
118-
empty_df = pd.DataFrame({
119-
'Success': [],
120-
'Predictor1': [],
121-
'Predictor2': [],
122-
'Interaction': []
123-
})
124-
144+
empty_df = pd.DataFrame({"Success": [], "Predictor1": [], "Predictor2": [], "Interaction": []})
145+
125146
# Test if fitting with empty data raises an error
126147
with pytest.raises(ValueError):
127-
fit_logit(empty_df, ['Predictor1', 'Predictor2'])
148+
fit_logit(empty_df, ["Predictor1", "Predictor2"])
128149

129-
# Test plotting with empty data (skip fitting since it's empty)
130-
fig, ax = plt.subplots()
131-
132-
# Ensure that no fitting occurs on an empty DataFrame
133-
if not empty_df.empty:
134-
result = fit_logit(empty_df, ['Predictor1', 'Predictor2'])
135-
contourset = plot_contour_map(
136-
ax, result, empty_df,
137-
'viridis', 'coolwarm', np.linspace(0, 1, 10), np.linspace(-2, 2, 50),
138-
np.linspace(-2, 2, 50), 'Predictor1', 'Predictor2', base=0
139-
)
140-
assert contourset is not None
141-
else:
142-
# Skip plotting if DataFrame is empty
143-
assert True # Ensures that we expect no result or plotting for empty DataFrame
144-
145-
plt.close(fig)
150+
# Close any created figures
151+
plt.close("all")
146152

147153

148154
def test_invalid_predictors(sample_data):
149155
"""Test with invalid predictors."""
150-
invalid_predictors = ['InvalidPredictor1', 'InvalidPredictor2']
151-
156+
invalid_predictors = ["InvalidPredictor1", "InvalidPredictor2"]
157+
152158
with pytest.raises(KeyError):
153159
fit_logit(sample_data, invalid_predictors)
154160

155161

156162
def test_logit_with_interaction(sample_data):
157163
"""Test logistic regression with interaction term."""
158-
sample_data['Interaction'] = sample_data['Predictor1'] * sample_data['Predictor2']
159-
result = fit_logit(sample_data, ['Predictor1', 'Predictor2'])
160-
164+
data = sample_data.copy()
165+
data["Interaction"] = data["Predictor1"] * data["Predictor2"]
166+
result = fit_logit(data, ["Predictor1", "Predictor2"])
167+
161168
# Ensure the interaction term is included in the result
162-
assert 'Interaction' in result.params.index
169+
assert "Interaction" in result.params.index
163170

164171

165172
def test_fit_logit_comprehensive(sample_data):
166173
"""Comprehensive test for fit_logit checking various aspects."""
167174
# Check valid predictors
168-
result = fit_logit(sample_data, ['Predictor1', 'Predictor2'])
169-
170-
# Validate coefficients are reasonable
171-
assert np.all(np.abs(result.params) > 1e-5) # Coefficients should not be too close to zero
175+
result = fit_logit(sample_data, ["Predictor1", "Predictor2"])
176+
177+
# Validate coefficients are reasonable (not exceeding expected ranges)
172178
assert np.all(np.abs(result.params) < 10) # Coefficients should not exceed 10
173179

174-
# Check if specific expected values are close (if known from actual model output)
175-
EXPECTED_PARAMS = np.array([0.34060709, -0.26968773, 0.31551482, 0.45824332]) # Update with actual expected values
176-
assert np.allclose(result.params.values, EXPECTED_PARAMS, atol=0.1) # Increased tolerance to 0.1
180+
# Check if all expected predictors are present in the result
181+
for predictor in ["Intercept", "Predictor1", "Predictor2", "Interaction"]:
182+
assert predictor in result.params.index
177183

178184
# Check p-values are valid
179185
assert np.all(np.isfinite(result.pvalues)) # P-values should be finite numbers
180-
# Check if any coefficient has a p-value less than 0.1 (10% significance level)
181-
assert np.any(result.pvalues < 0.1)
182-
183-
184-
def test_fit_logit_without_interaction():
185-
"""Test that fit_logit works without an Interaction column."""
186-
np.random.seed(42)
187-
n = 100
188-
189-
# Create data WITHOUT an Interaction column
190-
df_no_interaction = pd.DataFrame({
191-
'Success': np.random.randint(0, 2, size=n),
192-
'Predictor1': np.random.randn(n),
193-
'Predictor2': np.random.randn(n)
194-
})
195-
196-
# This should work without raising a KeyError
197-
result = fit_logit(df_no_interaction, ['Predictor1', 'Predictor2'])
198-
199-
# Verify the result is valid
200-
assert result is not None
201-
assert hasattr(result, 'params')
202-
203-
# Should have 3 parameters: Intercept, Predictor1, Predictor2 (no Interaction)
204-
assert len(result.params) == 3
205-
assert 'Intercept' in result.params.index
206-
assert 'Predictor1' in result.params.index
207-
assert 'Predictor2' in result.params.index
208-
assert 'Interaction' not in result.params.index
209-

msdbook/utils.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import numpy as np
44
import statsmodels.api as sm
55

6+
67
def fit_logit(dta, predictors):
78
"""Logistic regression with optional interaction term.
89
@@ -28,23 +29,18 @@ def fit_logit(dta, predictors):
2829

2930
# Add intercept column of 1s
3031
dta["Intercept"] = np.ones(np.shape(dta)[0])
31-
32-
# Get columns of predictors, starting with intercept
33-
cols = dta.columns.tolist()[-1:] + predictors
34-
35-
# Add interaction term if present in the data
36-
if "Interaction" in dta.columns:
37-
cols = cols + ["Interaction"]
38-
32+
33+
# Get columns of predictors
34+
cols = dta.columns.tolist()[-1:] + predictors + ["Interaction"]
35+
3936
# Fit logistic regression without the deprecated 'disp' argument
4037
logit = sm.Logit(dta["Success"], dta[cols])
41-
result = logit.fit(method='bfgs') # Use method='bfgs' or another supported method
42-
38+
result = logit.fit(method="bfgs") # Use method='bfgs' or another supported method
39+
4340
return result
4441

45-
def plot_contour_map(
46-
ax, result, dta, contour_cmap, dot_cmap, levels, xgrid, ygrid, xvar, yvar, base
47-
):
42+
43+
def plot_contour_map(ax, result, dta, contour_cmap, dot_cmap, levels, xgrid, ygrid, xvar, yvar):
4844
"""Plot the contour map"""
4945

5046
# Ignore tight layout warnings
@@ -60,12 +56,15 @@ def plot_contour_map(
6056
Z = np.reshape(z, np.shape(X))
6157

6258
contourset = ax.contourf(X, Y, Z, levels, cmap=contour_cmap, aspect="auto")
63-
59+
6460
# Plot scatter points based on the data
65-
xpoints = np.mean(dta[xvar].values.reshape(-1, 10), axis=1)
66-
ypoints = np.mean(dta[yvar].values.reshape(-1, 10), axis=1)
67-
colors = np.round(np.mean(dta["Success"].values.reshape(-1, 10), axis=1), 0)
68-
61+
# Trim data to ensure it's divisible by 10 for reshaping
62+
n = len(dta[xvar].values)
63+
n_trim = (n // 10) * 10
64+
xpoints = np.mean(dta[xvar].values[:n_trim].reshape(-1, 10), axis=1)
65+
ypoints = np.mean(dta[yvar].values[:n_trim].reshape(-1, 10), axis=1)
66+
colors = np.round(np.mean(dta["Success"].values[:n_trim].reshape(-1, 10), axis=1), 0)
67+
6968
ax.scatter(xpoints, ypoints, s=10, c=colors, edgecolor="none", cmap=dot_cmap)
7069
ax.set_xlim(np.min(xgrid), np.max(xgrid))
7170
ax.set_ylim(np.min(ygrid), np.max(ygrid))

notebooks/basin_users_logistic_regression.ipynb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -302,15 +302,16 @@
302302
" # plot contour map\n",
303303
" contourset = plot_contour_map(ax, result, dta, contour_cmap,\n",
304304
" dot_cmap, contour_levels, xgrid,\n",
305-
" ygrid, all_predictors[0], all_predictors[1], base)\n",
305+
" ygrid, all_predictors[0], all_predictors[1])\n",
306306
" \n",
307307
" ax.set_title(usernames[i])\n",
308308
" \n",
309309
"# set up colorbar\n",
310310
"cbar_ax = fig.add_axes([0.98, 0.15, 0.05, 0.7])\n",
311311
"cbar = fig.colorbar(contourset, cax=cbar_ax)\n",
312312
"cbar_ax.set_ylabel('Probability of Success', fontsize=16)\n",
313-
"cbar_ax.tick_params(axis='y', which='major', labelsize=12)\n"
313+
"cbar_ax.tick_params(axis='y', which='major', labelsize=12)\n",
314+
""
314315
]
315316
},
316317
{
@@ -382,4 +383,4 @@
382383
},
383384
"nbformat": 4,
384385
"nbformat_minor": 4
385-
}
386+
}

0 commit comments

Comments (0)