Skip to content

Commit 4071d0e

Browse files
Copilotcrvernon
andauthored
Make Interaction column optional in fit_logit (#164)
* Initial plan * Make Interaction column optional in fit_logit and add documentation Co-authored-by: crvernon <3947069+crvernon@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: crvernon <3947069+crvernon@users.noreply.github.com>
1 parent bd83d12 commit 4071d0e

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

msdbook/tests/test_utils.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,31 @@ def test_fit_logit_comprehensive(sample_data):
179179
assert np.all(np.isfinite(result.pvalues)) # P-values should be finite numbers
180180
# Check if any coefficient has a p-value less than 0.1 (10% significance level)
181181
assert np.any(result.pvalues < 0.1)
182+
183+
184+
def test_fit_logit_without_interaction():
185+
"""Test that fit_logit works without an Interaction column."""
186+
np.random.seed(42)
187+
n = 100
188+
189+
# Create data WITHOUT an Interaction column
190+
df_no_interaction = pd.DataFrame({
191+
'Success': np.random.randint(0, 2, size=n),
192+
'Predictor1': np.random.randn(n),
193+
'Predictor2': np.random.randn(n)
194+
})
195+
196+
# This should work without raising a KeyError
197+
result = fit_logit(df_no_interaction, ['Predictor1', 'Predictor2'])
198+
199+
# Verify the result is valid
200+
assert result is not None
201+
assert hasattr(result, 'params')
202+
203+
# Should have 3 parameters: Intercept, Predictor1, Predictor2 (no Interaction)
204+
assert len(result.params) == 3
205+
assert 'Intercept' in result.params.index
206+
assert 'Predictor1' in result.params.index
207+
assert 'Predictor2' in result.params.index
208+
assert 'Interaction' not in result.params.index
182209

msdbook/utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,37 @@
44
import statsmodels.api as sm
55

66
def fit_logit(dta, predictors):
7-
"""Logistic regression"""
7+
"""Logistic regression with optional interaction term.
8+
9+
Parameters
10+
----------
11+
dta : pandas.DataFrame
12+
Input data containing the 'Success' column and predictor columns.
13+
If an 'Interaction' column is present, it will be included in the model.
14+
predictors : list of str
15+
List of predictor column names to include in the model.
16+
17+
Returns
18+
-------
19+
statsmodels.discrete.discrete_model.BinaryResultsWrapper
20+
Fitted logistic regression model.
21+
22+
Notes
23+
-----
24+
The function automatically adds an intercept column. If the data contains
25+
an 'Interaction' column, it will be included in the model alongside the
26+
specified predictors.
27+
"""
828

929
# Add intercept column of 1s
1030
dta["Intercept"] = np.ones(np.shape(dta)[0])
1131

12-
# Get columns of predictors
13-
cols = dta.columns.tolist()[-1:] + predictors + ["Interaction"]
32+
# Get columns of predictors, starting with intercept
33+
cols = dta.columns.tolist()[-1:] + predictors
34+
35+
# Add interaction term if present in the data
36+
if "Interaction" in dta.columns:
37+
cols = cols + ["Interaction"]
1438

1539
# Fit logistic regression without the deprecated 'disp' argument
1640
logit = sm.Logit(dta["Success"], dta[cols])

0 commit comments

Comments
 (0)