Skip to content

Commit a7cdfe8

Browse files
committed
BUG: Correct handling of q in SUR constraints
Correct reindexing of q so that NaN is not introduced closes #633
1 parent 1969b5b commit a7cdfe8

File tree

6 files changed

+42
-11
lines changed

6 files changed

+42
-11
lines changed

linearmodels/panel/model.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -771,9 +771,7 @@ def _setup_clusters(
771771
cat = Categorical(formatted_clusters.dataframe[col])
772772
# TODO: Bug in pandas-stubs
773773
# https://github.com/pandas-dev/pandas-stubs/issues/111
774-
formatted_clusters.dataframe[col] = cat.codes.astype(
775-
np.int64
776-
) # type: ignore
774+
formatted_clusters.dataframe[col] = cat.codes.astype(np.int64) # type: ignore
777775
clusters_frame = formatted_clusters.dataframe
778776

779777
cluster_entity = bool(cov_config_upd.pop("cluster_entity", False))
@@ -2184,9 +2182,7 @@ def _setup_clusters(
21842182
cluster_max.T, index=index, columns=clusters_panel.vars
21852183
)
21862184
# TODO: Bug in pandas-stubs prevents using Hashable | None
2187-
clusters_frame = clusters_frame.loc[reindex].astype(
2188-
np.int64
2189-
) # type: ignore
2185+
clusters_frame = clusters_frame.loc[reindex].astype(np.int64) # type: ignore
21902186
cov_config_upd["clusters"] = clusters_frame
21912187

21922188
return cov_config_upd

linearmodels/shared/hypotheses.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _parse_single(constraint: str) -> tuple[str, float]:
198198

199199

200200
def _reparse_constraint_formula(
201-
formula: str | list[str] | dict[str, float]
201+
formula: str | list[str] | dict[str, float],
202202
) -> str | dict[str, float]:
203203
# TODO: Test against variable names constaining , or =
204204
if isinstance(formula, Mapping):

linearmodels/system/_utility.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -262,7 +262,8 @@ def __init__(
262262
raise TypeError("q must be a Series or an array")
263263
if r.shape[0] != q.shape[0]:
264264
raise ValueError("Constraint inputs are not shape compatible")
265-
q_pd = pd.Series(q, index=r_pd.index)
265+
q_pd = pd.Series(q)
266+
q_pd.index = r_pd.index
266267
else:
267268
q_pd = pd.Series(np.zeros(r_pd.shape[0]), index=r_pd.index)
268269
self._q_pd = q_pd

linearmodels/system/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@
9696

9797

9898
def _missing_weights(
99-
weights: Mapping[str, linearmodels.typing.data.ArrayLike | None]
99+
weights: Mapping[str, linearmodels.typing.data.ArrayLike | None],
100100
) -> None:
101101
"""Raise warning if missing weighs found"""
102102
missing = [key for key in weights if weights[key] is None]

linearmodels/tests/system/test_sur.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -909,3 +909,37 @@ def test_unknown_method():
909909
mod = SUR(generate_data(k=3))
910910
with pytest.raises(ValueError, match="method must be 'ols' or 'gls'"):
911911
mod.fit(method="other")
912+
913+
914+
def test_sur_contraint_with_value():
915+
n = 100
916+
rg = np.random.RandomState(np.random.MT19937(12345))
917+
x1 = rg.normal(size=n)
918+
x2 = rg.normal(size=n)
919+
x3 = rg.normal(size=n)
920+
921+
y1 = 3 + 1.5 * x1 - 2.0 * x2 + np.random.normal(size=n)
922+
y2 = -1 + 0.5 * x2 + 1.2 * x3 + np.random.normal(size=n)
923+
924+
data = DataFrame({"x1": x1, "x2": x2, "x3": x3, "y1": y1, "y2": y2})
925+
926+
equations = {"eq1": "y1 ~ x1 + x2", "eq2": "y2 ~ x2 + x3"}
927+
928+
model = SUR.from_formula(equations, data)
929+
930+
# coefficients of eq1_x1 and eq2_x2 are equal
931+
r = DataFrame(
932+
[[0] * 4], columns=model.param_names, index=["rest"], dtype=np.float64
933+
)
934+
r.iloc[0, 0] = -1
935+
r.iloc[0, 2] = 1
936+
937+
q = Series([0])
938+
model.add_constraints(r, q)
939+
result = model.fit()
940+
941+
# No error without q
942+
model = SUR.from_formula(equations, data)
943+
model.add_constraints(r)
944+
result_without_q = model.fit()
945+
assert_allclose(result.params, result_without_q.params)

requirements-dev.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
xarray>=0.16
22
mypy>=1.3
3-
black[jupyter]==24.4.0
3+
black[jupyter]==24.10.0
44
pytest>=7.3.0,<8
55
isort>=5.12
66
ipython
77
matplotlib
8-
ruff
8+
ruff>=0.8.6
99
jupyterlab-code-formatter
1010
flake8
1111
jupyter

0 commit comments

Comments
 (0)