Skip to content

Commit c8923e0

Browse files
authored
Combine parsing of initial values (#805)
1 parent 6507314 commit c8923e0

File tree

2 files changed

+37
-81
lines changed

2 files changed

+37
-81
lines changed

src/glum/_glm.py

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,6 @@
3838
_assert_all_finite,
3939
check_consistent_length,
4040
check_is_fitted,
41-
check_random_state,
4241
column_or_1d,
4342
)
4443

@@ -99,7 +98,7 @@ class WaldTestResult(NamedTuple):
9998
df: int
10099

101100

102-
def check_array_tabmat_compliant(mat: ArrayLike, drop_first: int = False, **kwargs):
101+
def check_array_tabmat_compliant(mat: ArrayLike, drop_first: bool = False, **kwargs):
103102
to_copy = kwargs.get("copy", False)
104103

105104
if isinstance(mat, pd.DataFrame):
@@ -654,33 +653,6 @@ def _setup_sparse_p2(P2):
654653
return P2
655654

656655

657-
def initialize_start_params(
658-
start_params: Optional[np.ndarray], n_cols: int, fit_intercept: bool, dtype
659-
) -> Optional[np.ndarray]:
660-
if start_params is None:
661-
return None
662-
663-
start_params = check_array(
664-
start_params,
665-
accept_sparse=False,
666-
force_all_finite=True,
667-
ensure_2d=False,
668-
dtype=dtype,
669-
copy=True,
670-
)
671-
672-
start_params = cast(np.ndarray, start_params)
673-
674-
if start_params.shape != (n_cols + fit_intercept,):
675-
raise ValueError(
676-
"Start values for parameters must have the right length and dimension; "
677-
f"got (length={start_params.shape[0]}, ndim={start_params.ndim}); "
678-
f"needed (length={n_cols + fit_intercept}, ndim=1)."
679-
)
680-
681-
return start_params
682-
683-
684656
def is_pos_semidef(p: Union[sparse.spmatrix, np.ndarray]) -> Union[bool, np.bool_]:
685657
"""
686658
Checks for positive semidefiniteness of ``p`` if ``p`` is a matrix, or
@@ -833,13 +805,13 @@ def link_instance(self) -> Link:
833805

834806
def _get_start_coef(
835807
self,
836-
start_params,
837808
X: Union[tm.MatrixBase, tm.StandardizedMatrix],
838809
y: np.ndarray,
839810
sample_weight: np.ndarray,
840811
offset: Optional[np.ndarray],
841-
col_means: Optional[np.ndarray],
812+
col_means: np.ndarray,
842813
col_stds: Optional[np.ndarray],
814+
dtype,
843815
) -> np.ndarray:
844816
if self.warm_start and hasattr(self, "coef_"):
845817
coef = self.coef_ # type: ignore
@@ -849,7 +821,7 @@ def _get_start_coef(
849821
if self._center_predictors:
850822
_standardize_warm_start(coef, col_means, col_stds) # type: ignore
851823

852-
elif start_params is None:
824+
elif self.start_params is None:
853825
if self.fit_intercept:
854826
coef = np.zeros(
855827
X.shape[1] + 1, dtype=_float_itemsize_to_dtype[X.dtype.itemsize]
@@ -863,13 +835,28 @@ def _get_start_coef(
863835
)
864836

865837
else: # assign given array as start values
866-
coef = start_params
838+
coef = check_array(
839+
self.start_params,
840+
accept_sparse=False,
841+
force_all_finite=True,
842+
ensure_2d=False,
843+
dtype=dtype,
844+
copy=True,
845+
)
846+
847+
if coef.shape != (len(col_means) + self.fit_intercept,):
848+
raise ValueError(
849+
"Start values for parameters must have the right length "
850+
f"and dimension; got {coef.shape}, needed "
851+
f"({len(col_means) + self.fit_intercept},)."
852+
)
853+
867854
if self._center_predictors:
868855
_standardize_warm_start(coef, col_means, col_stds) # type: ignore
869856

870857
# If starting values are outside the specified bounds (if set),
871858
# bring the starting value exactly at the bound.
872-
idx = 1 if self.fit_intercept else 0
859+
idx = int(self.fit_intercept)
873860
if self.lower_bounds is not None:
874861
if np.any(coef[idx:] < self.lower_bounds):
875862
warnings.warn(
@@ -970,8 +957,6 @@ def _set_up_for_fit(self, y: np.ndarray) -> None:
970957
else:
971958
self._gradient_tol = self.gradient_tol
972959

973-
self._random_state = check_random_state(self.random_state)
974-
975960
# 1.4 additional validations ##########################################
976961
if self.check_input:
977962
if not np.all(self._family_instance.in_y_range(y)):
@@ -980,12 +965,6 @@ def _set_up_for_fit(self, y: np.ndarray) -> None:
980965
f"{self._family_instance.__class__.__name__}."
981966
)
982967

983-
def _tear_down_from_fit(self):
984-
"""
985-
Delete attributes that were only needed for the fit method.
986-
"""
987-
del self._random_state
988-
989968
def _get_alpha_path(
990969
self,
991970
P1_no_alpha: np.ndarray,
@@ -1083,8 +1062,8 @@ def _solve(
10831062
b_ineq: Optional[np.ndarray],
10841063
) -> np.ndarray:
10851064
"""
1086-
Must be run after running :func:`_set_up_for_fit` and before running
1087-
:func:`_tear_down_from_fit`. Sets ``self.coef_`` and ``self.intercept_``.
1065+
Must be run after running :func:`_set_up_for_fit`. Sets
1066+
``self.coef_`` and ``self.intercept_``.
10881067
"""
10891068
fixed_inner_tol = None
10901069
if (
@@ -1527,7 +1506,7 @@ def coef_table(
15271506
captured_context = capture_context(
15281507
context + 1 if isinstance(context, int) else context
15291508
)
1530-
if (X is None) and not hasattr(self, "covariance_matrix_"):
1509+
if (X is None) and (getattr(self, "covariance_matrix_", None) is None):
15311510
return pd.Series(beta, index=names, name="coef")
15321511

15331512
covariance_matrix = self.covariance_matrix(
@@ -2374,10 +2353,9 @@ def _should_copy_X(self):
23742353
def _set_up_and_check_fit_args(
23752354
self,
23762355
X: ArrayLike,
2377-
y: ArrayLike,
2356+
y: Optional[ArrayLike],
23782357
sample_weight: Optional[VectorLike],
23792358
offset: Optional[VectorLike],
2380-
solver: str,
23812359
force_all_finite,
23822360
context: Optional[Mapping[str, Any]] = None,
23832361
) -> tuple[
@@ -2390,7 +2368,7 @@ def _set_up_and_check_fit_args(
23902368
Union[str, np.ndarray],
23912369
]:
23922370
dtype = [np.float64, np.float32]
2393-
stype = ["csc"] if solver == "irls-cd" else ["csc", "csr"]
2371+
stype = ["csc"] if self.solver == "irls-cd" else ["csc", "csr"]
23942372

23952373
P1 = self.P1
23962374
P2 = self.P2
@@ -2418,8 +2396,8 @@ def _set_up_and_check_fit_args(
24182396
context=context,
24192397
)
24202398

2421-
self.y_model_spec_ = y.model_spec
2422-
y = y.toarray().ravel()
2399+
self.y_model_spec_ = y.model_spec # type: ignore
2400+
y = y.toarray().ravel() # type: ignore
24232401

24242402
X = tm.from_formula(
24252403
formula=rhs,
@@ -3128,7 +3106,6 @@ def fit(
31283106
y,
31293107
sample_weight,
31303108
offset,
3131-
solver=self.solver,
31323109
force_all_finite=self.force_all_finite,
31333110
context=captured_context,
31343111
)
@@ -3154,13 +3131,6 @@ def fit(
31543131
if np.any(lower_bounds > upper_bounds):
31553132
raise ValueError("Upper bounds must be higher than lower bounds.")
31563133

3157-
start_params = initialize_start_params(
3158-
self.start_params,
3159-
n_cols=X.shape[1],
3160-
fit_intercept=self.fit_intercept,
3161-
dtype=[np.float64, np.float32],
3162-
)
3163-
31643134
# 1.4 additional validations ##########################################
31653135
if self.check_input:
31663136
# check if P2 is positive semidefinite
@@ -3204,7 +3174,13 @@ def fit(
32043174
#######################################################################
32053175

32063176
coef = self._get_start_coef(
3207-
start_params, X, y, sample_weight, offset, col_means, col_stds
3177+
X,
3178+
y,
3179+
sample_weight,
3180+
offset,
3181+
col_means,
3182+
col_stds,
3183+
dtype=[np.float64, np.float32],
32083184
)
32093185

32103186
#######################################################################
@@ -3291,8 +3267,6 @@ def fit(
32913267
col_means, col_stds, 0.0, coef
32923268
)
32933269

3294-
self._tear_down_from_fit()
3295-
32963270
self.covariance_matrix_ = None
32973271
if store_covariance_matrix:
32983272
self.covariance_matrix(

src/glum/_glm_cv.py

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
_standardize,
1616
_unstandardize,
1717
check_bounds,
18-
initialize_start_params,
1918
is_pos_semidef,
2019
setup_p1,
2120
setup_p2,
@@ -500,7 +499,6 @@ def fit(
500499
y,
501500
sample_weight,
502501
offset,
503-
solver=self.solver,
504502
force_all_finite=self.force_all_finite,
505503
context=captured_context,
506504
)
@@ -588,13 +586,6 @@ def _get_deviance(coef):
588586
):
589587
assert isinstance(self._link_instance, LogLink)
590588

591-
start_params = initialize_start_params(
592-
self.start_params,
593-
n_cols=X.shape[1],
594-
fit_intercept=self.fit_intercept,
595-
dtype=[np.float64, np.float32],
596-
)
597-
598589
P1_no_alpha = setup_p1(P1, X, X.dtype, 1, l1)
599590
P2_no_alpha = setup_p2(P2, X, _stype, X.dtype, 1, l1)
600591

@@ -620,13 +611,13 @@ def _get_deviance(coef):
620611
)
621612

622613
coef = self._get_start_coef(
623-
start_params,
624614
x_train,
625615
y_train,
626616
w_train,
627617
offset_train,
628618
col_means,
629619
col_stds,
620+
dtype=[np.float64, np.float32],
630621
)
631622

632623
if self.check_input:
@@ -748,15 +739,8 @@ def _get_deviance(coef):
748739
P2,
749740
)
750741

751-
start_params = initialize_start_params(
752-
self.start_params,
753-
n_cols=X.shape[1],
754-
fit_intercept=self.fit_intercept,
755-
dtype=X.dtype,
756-
)
757-
758742
coef = self._get_start_coef(
759-
start_params, X, y, sample_weight, offset, col_means, col_stds
743+
X, y, sample_weight, offset, col_means, col_stds, dtype=X.dtype
760744
)
761745

762746
coef = self._solve(
@@ -781,8 +765,6 @@ def _get_deviance(coef):
781765
# set intercept to zero as the other linear models do
782766
self.intercept_, self.coef_ = _unstandardize(col_means, col_stds, 0.0, coef)
783767

784-
self._tear_down_from_fit()
785-
786768
self.covariance_matrix_ = None
787769
if store_covariance_matrix:
788770
self.covariance_matrix(

0 commit comments

Comments
 (0)