Skip to content

Commit 959288d

Browse files
mlondschien, stanmart, lbittarello
authored
Use dtype dependent precision (#844)
Co-authored-by: Martin Stancsics <martin.stancsics@quantco.com> Co-authored-by: Luca Bittarello <15511539+lbittarello@users.noreply.github.com>
1 parent df6f372 commit 959288d

File tree

4 files changed

+15
-6
lines changed

4 files changed

+15
-6
lines changed

CHANGELOG.rst

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,10 @@ Changelog
1616
- New fitted attributes ``col_means_`` and ``col_stds_`` for classes :class:`~glum.GeneralizedLinearRegressor` and :class:`~glum.GeneralizedLinearRegressorCV`.
1717
- :class:`~glum.GeneralizedLinearRegressor` now prints more informative logs when fitting with ``alpha_search=True`` and ``verbose=True``.
1818

19-
**Bug fix:
19+
**Bug fixes:**
2020

2121
- Fixed a bug where :meth:`glum.GeneralizedLinearRegressor.fit` would raise a ``dtype`` mismatch error if fit with ``alpha_search=True``.
22+
- Use data type (``float64`` or ``float32``) dependent precision in solvers.
2223

2324
3.0.2 - 2024-06-25
2425
------------------

src/glum/_cd_fast.pyx

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -117,7 +117,8 @@ def enet_coordinate_descent_gram(int[::1] active_set,
117117
bint has_lower_bounds,
118118
floating[:] lower_bounds,
119119
bint has_upper_bounds,
120-
floating[:] upper_bounds):
120+
floating[:] upper_bounds,
121+
floating eps):
121122
"""Cython version of the coordinate descent algorithm
122123
for Elastic-Net regression
123124
We minimize
@@ -162,7 +163,7 @@ def enet_coordinate_descent_gram(int[::1] active_set,
162163
else:
163164
P1_ii = P1[ii - intercept]
164165

165-
if Q[active_set_ii, active_set_ii] == 0.0:
166+
if Q[active_set_ii, active_set_ii] <= eps:
166167
continue
167168

168169
w_ii = w[ii] # Store previous value

src/glum/_glm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ def _one_over_var_inf_to_val(arr: np.ndarray, val: float) -> np.ndarray:
452452
453453
If values are zeros, return val.
454454
"""
455-
zeros = np.where(np.abs(arr) < 1e-7)
455+
zeros = np.where(np.abs(arr) < np.sqrt(np.finfo(arr.dtype).eps))
456456
with np.errstate(divide="ignore"):
457457
one_over = 1 / arr
458458
one_over[zeros] = val

src/glum/_solvers.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ def _cd_solver(state, data, active_hessian):
7070
data._lower_bounds,
7171
data.has_upper_bounds,
7272
data._upper_bounds,
73+
np.finfo(state.coef.dtype).eps * 16,
7374
)
7475
return new_coef - state.coef, n_cycles
7576

@@ -546,6 +547,9 @@ def __init__(self, coef, data):
546547
self.line_search_runtime = None
547548
self.quadratic_update_runtime = None
548549

550+
# used in the line-search Armijo stopping criterion
551+
self.large_number = 1e30 if data.X.dtype == np.float32 else 1e43
552+
549553
def _record_iteration(self):
550554
self.n_iter += 1
551555

@@ -759,7 +763,9 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
759763
"""
760764
# line search parameters
761765
(beta, sigma) = (0.5, 0.0001)
762-
eps = 16 * np.finfo(state.obj_val.dtype).eps # type: ignore
766+
# Use np.finfo(state.coef.dtype).eps instead of np.finfo(state.obj_val.dtype).eps, as
767+
# state.obj_val is np.float64, even if the data is np.float32.
768+
eps = 16 * np.finfo(state.coef.dtype).eps # type: ignore
763769

764770
# line search by sequence beta^k, k=0, 1, ..
765771
# F(w + lambda d) - F(w) <= lambda * bound
@@ -792,7 +798,8 @@ def line_search(state: IRLSState, data: IRLSData, d: np.ndarray):
792798
)
793799
# 1. Check Armijo / sufficient decrease condition.
794800
loss_improvement = obj_val_wd - state.obj_val
795-
if mu_wd.max() < 1e43 and loss_improvement <= factor * bound:
801+
802+
if mu_wd.max() < state.large_number and loss_improvement <= factor * bound:
796803
break
797804
# 2. Deal with relative loss differences around machine precision.
798805
tiny_loss = np.abs(state.obj_val * eps) # type: ignore

0 commit comments

Comments
 (0)