
Commit 23af014 (1 parent: 7c4136f)

use GPUs for GLMs Pt.2
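In short: this commit drops the CPU-only `NotImplementedError` guards from the GLM helper methods and, in both `fit()` methods, computes the least-squares warm start with `jax.numpy.linalg.lstsq` whenever `backend != "cpu"`, with JAX imported as an optional dependency. A minimal usage sketch of the net effect, not part of this commit — it assumes JAX is installed and that `GLMRegressor` accepts the `backend` values used in the diffs below; the data is synthetic:

```python
# Hypothetical usage sketch: fit a GLM regressor on the non-CPU
# code path introduced in this commit.
import numpy as np
import nnetsauce as ns

rng = np.random.default_rng(0)
X = rng.standard_normal((100, 5))
y = X @ rng.standard_normal(5) + 0.1 * rng.standard_normal(100)

# backend="gpu" routes the lstsq initialization through jax.numpy
# (see glmRegressor.py below); the default "cpu" keeps the NumPy path.
reg = ns.GLMRegressor(backend="gpu")
reg.fit(X, y)
print(reg.predict(X[:5]))
```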

File tree: 4 files changed, +16 −165 lines

examples/glm_regression_quantile_gpopt.py (0 additions, 138 deletions)

This file was deleted.

nnetsauce/glm/glm.py (0 additions, 24 deletions)

```diff
@@ -146,29 +146,13 @@ def compute_XB(self, X, beta=None, row_index=None):
         return mo.safe_sparse_dot(X[row_index, :], self.beta_, backend=self.backend)
 
     def compute_XB2(self, X, beta=None, row_index=None):
-        if self.backend != "cpu":
-            raise NotImplementedError(
-                "GLM.compute_XB is only implemented for backend='cpu'"
-            )
         def f00(X):
-            if self.backend != "cpu":
-                raise NotImplementedError(
-                    "GLM.compute_XB is only implemented for backend='cpu'"
-                )
             return mo.safe_sparse_dot(X, self.beta_, backend=self.backend)
 
         def f01(X):
-            if self.backend != "cpu":
-                raise NotImplementedError(
-                    "GLM.compute_XB is only implemented for backend='cpu'"
-                )
             return mo.safe_sparse_dot(X[row_index, :], self.beta_, backend=self.backend)
 
         def f11(X):
-            if self.backend != "cpu":
-                raise NotImplementedError(
-                    "GLM.compute_XB is only implemented for backend='cpu'"
-                )
             return mo.safe_sparse_dot(X[row_index, :], beta, backend=self.backend)
 
         def f10(X):
@@ -186,10 +170,6 @@ def f10(X):
         return h_result[result_code](X)
 
     def penalty(self, beta1, beta2, lambda1, lambda2, alpha1, alpha2):
-        if self.backend != "cpu":
-            raise NotImplementedError(
-                "GLM.compute_XB is only implemented for backend='cpu'"
-            )
         res = lambda1 * (
             0.5 * (1 - alpha1) * np.sum(np.square(beta1))
             + alpha1 * np.sum(np.abs(beta1))
@@ -201,10 +181,6 @@ def penalty(self, beta1, beta2, lambda1, lambda2, alpha1, alpha2):
         return res
 
     def compute_penalty(self, group_index, beta):
-        if self.backend != "cpu":
-            raise NotImplementedError(
-                "GLM.compute_XB is only implemented for backend='cpu'"
-            )
         return self.penalty(
             beta1=beta[0:group_index],
             beta2=beta[group_index: len(beta)],
```
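Why these deletions are safe: every product in `compute_XB2` already goes through `mo.safe_sparse_dot(..., backend=self.backend)`, and `penalty`/`compute_penalty` are plain NumPy reductions, so the guards (whose message referenced `compute_XB` in any case) were the only thing blocking non-CPU backends. A minimal sketch of what a backend-dispatching dot product can look like — illustration only, the real `matrixops.safe_sparse_dot` in nnetsauce may be implemented differently:

```python
# Illustrative sketch; nnetsauce's matrixops.safe_sparse_dot may differ.
import numpy as np

try:
    import jax.numpy as jnp
except ImportError:
    jnp = None  # JAX stays optional

def safe_sparse_dot_sketch(a, b, backend="cpu"):
    """Matrix product via NumPy on CPU, via jax.numpy otherwise."""
    if backend == "cpu" or jnp is None:
        return np.dot(a, b)
    # jnp.dot executes on GPU/TPU when JAX is configured with such a device
    return jnp.dot(a, b)
```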

nnetsauce/glm/glmClassifier.py (8 additions, 2 deletions)

```diff
@@ -10,7 +10,10 @@
 from sklearn.base import ClassifierMixin
 from ..optimizers import Optimizer
 from scipy.special import logsumexp, expit, erf
-
+try:
+    import jax.numpy as jnp
+except ImportError:
+    pass
 
 class GLMClassifier(GLM, ClassifierMixin):
     """Generalized 'linear' models using quasi-randomized networks (classification)
@@ -244,7 +247,10 @@ def fit(self, X, y, **kwargs):
         Y = self.optimizer.one_hot_encode(output_y, self.n_classes)
 
         # initialization
-        beta_ = np.linalg.lstsq(scaled_Z, Y, rcond=None)[0]
+        if self.backend == "cpu":
+            beta_ = np.linalg.lstsq(scaled_Z, Y, rcond=None)[0]
+        else:
+            beta_ = jnp.linalg.lstsq(scaled_Z, Y, rcond=None)[0]
 
         # optimization
         # fit(self, loss_func, response, x0, **kwargs):
```
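The same initialization pattern, extracted into a standalone sketch for clarity. `jnp.linalg.lstsq` mirrors NumPy's signature and return tuple (solution first), so the two branches are interchangeable apart from where they run:

```python
# Sketch of the backend-dependent least-squares warm start used in fit().
import numpy as np

try:
    import jax.numpy as jnp
except ImportError:
    jnp = None  # only needed for non-CPU backends

def init_beta(scaled_Z, Y, backend="cpu"):
    """Least-squares warm start for the optimizer, on CPU or via JAX."""
    if backend == "cpu":
        return np.linalg.lstsq(scaled_Z, Y, rcond=None)[0]
    # same signature as np.linalg.lstsq; first element is the solution
    return jnp.linalg.lstsq(scaled_Z, Y, rcond=None)[0]
```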

nnetsauce/glm/glmRegressor.py (8 additions, 1 deletion)

```diff
@@ -9,6 +9,10 @@
 from ..utils import matrixops as mo
 from sklearn.base import RegressorMixin
 from ..optimizers import Optimizer
+try:
+    import jax.numpy as jnp
+except ImportError:
+    pass
 
 
 class GLMRegressor(GLM, RegressorMixin):
@@ -239,7 +243,10 @@ def fit(self, X, y, **kwargs):
 
         centered_y, scaled_Z = self.cook_training_set(y=y, X=X, **kwargs)
         # initialization
-        beta_ = np.linalg.lstsq(scaled_Z, centered_y, rcond=None)[0]
+        if self.backend == "cpu":
+            beta_ = np.linalg.lstsq(scaled_Z, centered_y, rcond=None)[0]
+        else:
+            beta_ = jnp.linalg.lstsq(scaled_Z, centered_y, rcond=None)[0]
         # optimization
         # fit(self, loss_func, response, x0, **kwargs):
         # loss_func(self, beta, group_index, X, y,
```
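One note on the `try/except ImportError: pass` guard that both files add: it keeps JAX an optional dependency. The module still imports on machines without JAX, and `jnp` is only referenced on the non-CPU branch, which is not usable without JAX anyway. A quick sanity-check sketch (requires JAX; values arbitrary) that the two `lstsq` calls agree on the warm start:

```python
# Sanity-check sketch: NumPy and JAX least-squares produce matching
# warm starts (JAX defaults to float32, hence the loose tolerance).
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(1)
Z = rng.standard_normal((50, 4))
y = Z @ np.array([1.0, -2.0, 0.5, 3.0]) + 0.01 * rng.standard_normal(50)

beta_np = np.linalg.lstsq(Z, y, rcond=None)[0]
beta_jnp = jnp.linalg.lstsq(jnp.asarray(Z), jnp.asarray(y), rcond=None)[0]

print(np.allclose(beta_np, np.asarray(beta_jnp), atol=1e-4))  # expect True
```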
