Skip to content

Commit 369b735

Browse files
committed
TST: Add cat test with dtypes
1 parent 60a6524 commit 369b735

File tree

1 file changed

+32
-0
lines changed

1 file changed

+32
-0
lines changed

linearmodels/tests/iv/test_formulas.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,3 +367,35 @@ def test_formula_escape():
367367
assert "x 1" in res.params.index
368368
assert "y space" in str(summ)
369369
assert "Instruments: z 0" in str(summ)
370+
371+
372+
@pytest.mark.parametrize("dtype", [str, "category", object])
373+
def test_formula_categorical_equiv(data, model_and_func, dtype):
374+
model, func = model_and_func
375+
data = data.copy()
376+
rs = np.random.RandomState(12345)
377+
data["d"] = rs.choice(["a", "b", "c", "d"], size=data.shape[0])
378+
data["d"] = data["d"].astype(dtype)
379+
formula = "y ~ 1 + d + x2 + [x3 ~ z1 + z2]"
380+
mod = model.from_formula(formula, data)
381+
res = mod.fit()
382+
print(res)
383+
aug_data = data.copy()
384+
aug_data["d[T.b]"] = (data["d"] == "b").astype(float)
385+
aug_data["d[T.c]"] = (data["d"] == "c").astype(float)
386+
aug_data["d[T.d]"] = (data["d"] == "d").astype(float)
387+
exog = ["Intercept", "d[T.b]", "d[T.c]", "d[T.d]", "x2"]
388+
endog = ["x3"]
389+
instr = ["z1", "z2"]
390+
res_direct = model(
391+
aug_data["y"], aug_data[exog], aug_data[endog], aug_data[instr]
392+
).fit()
393+
assert_allclose(res.rsquared, res_direct.rsquared)
394+
assert list(res.params.index) == [
395+
"Intercept",
396+
"d[T.b]",
397+
"d[T.c]",
398+
"d[T.d]",
399+
"x2",
400+
"x3",
401+
]

0 commit comments

Comments
 (0)