Skip to content

Commit cbea008

Browse files
use GPUs for GLMs Pt.3
1 parent 23af014 commit cbea008

File tree

3 files changed

+23
-12
lines changed

3 files changed

+23
-12
lines changed

nnetsauce/glm/glm.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class GLM(Base):
7272
7373
optimizer: object
7474
optimizer, from class nnetsauce.utils.Optimizer
75-
75+
7676
backend: str.
7777
"cpu" or "gpu" or "tpu".
7878
@@ -132,34 +132,42 @@ def __init__(
132132
self.backend = backend
133133
self.beta_ = None
134134

135-
def compute_XB(self, X, beta=None, row_index=None):
135+
def compute_XB(self, X, beta=None, row_index=None):
136136
if beta is not None:
137137
if row_index is None:
138138
return mo.safe_sparse_dot(X, beta, backend=self.backend)
139139

140-
return mo.safe_sparse_dot(X[row_index, :], beta, backend=self.backend)
140+
return mo.safe_sparse_dot(
141+
X[row_index, :], beta, backend=self.backend
142+
)
141143

142144
# self.beta_ is None in this case
143145
if row_index is None:
144146
return mo.safe_sparse_dot(X, self.beta_, backend=self.backend)
145147

146-
return mo.safe_sparse_dot(X[row_index, :], self.beta_, backend=self.backend)
148+
return mo.safe_sparse_dot(
149+
X[row_index, :], self.beta_, backend=self.backend
150+
)
147151

148152
def compute_XB2(self, X, beta=None, row_index=None):
149153
def f00(X):
150154
return mo.safe_sparse_dot(X, self.beta_, backend=self.backend)
151155

152156
def f01(X):
153-
return mo.safe_sparse_dot(X[row_index, :], self.beta_, backend=self.backend)
157+
return mo.safe_sparse_dot(
158+
X[row_index, :], self.beta_, backend=self.backend
159+
)
154160

155161
def f11(X):
156-
return mo.safe_sparse_dot(X[row_index, :], beta, backend=self.backend)
162+
return mo.safe_sparse_dot(
163+
X[row_index, :], beta, backend=self.backend
164+
)
157165

158166
def f10(X):
159167
if self.backend != "cpu":
160168
raise NotImplementedError(
161-
"GLM.compute_XB is only implemented for backend='cpu'"
162-
)
169+
"GLM.compute_XB is only implemented for backend='cpu'"
170+
)
163171
return mo.safe_sparse_dot(X, beta, backend=self.backend)
164172

165173
h_result = {"00": f00, "01": f01, "11": f11, "10": f10}

nnetsauce/glm/glmClassifier.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,13 @@
1010
from sklearn.base import ClassifierMixin
1111
from ..optimizers import Optimizer
1212
from scipy.special import logsumexp, expit, erf
13-
try:
13+
14+
try:
1415
import jax.numpy as jnp
1516
except ImportError:
1617
pass
1718

19+
1820
class GLMClassifier(GLM, ClassifierMixin):
1921
"""Generalized 'linear' models using quasi-randomized networks (classification)
2022
@@ -76,7 +78,7 @@ class GLMClassifier(GLM, ClassifierMixin):
7678
7779
optimizer: object
7880
optimizer, from class nnetsauce.Optimizer
79-
81+
8082
backend: str.
8183
"cpu" or "gpu" or "tpu".
8284

nnetsauce/glm/glmRegressor.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from ..utils import matrixops as mo
1010
from sklearn.base import RegressorMixin
1111
from ..optimizers import Optimizer
12-
try:
12+
13+
try:
1314
import jax.numpy as jnp
1415
except ImportError:
1516
pass
@@ -83,7 +84,7 @@ class GLMRegressor(GLM, RegressorMixin):
8384
8485
optimizer: object
8586
optimizer, from class nnetsauce.utils.Optimizer
86-
87+
8788
backend: str.
8889
"cpu" or "gpu" or "tpu".
8990

0 commit comments

Comments
 (0)