@@ -72,7 +72,7 @@ class GLM(Base):
7272
7373 optimizer: object
7474 optimizer, from class nnetsauce.utils.Optimizer
75-
75+
7676 backend: str.
7777 "cpu" or "gpu" or "tpu".
7878
@@ -132,34 +132,42 @@ def __init__(
132132 self .backend = backend
133133 self .beta_ = None
134134
135- def compute_XB (self , X , beta = None , row_index = None ):
135+ def compute_XB (self , X , beta = None , row_index = None ):
136136 if beta is not None :
137137 if row_index is None :
138138 return mo .safe_sparse_dot (X , beta , backend = self .backend )
139139
140- return mo .safe_sparse_dot (X [row_index , :], beta , backend = self .backend )
140+ return mo .safe_sparse_dot (
141+ X [row_index , :], beta , backend = self .backend
142+ )
141143
142144 # self.beta_ is None in this case
143145 if row_index is None :
144146 return mo .safe_sparse_dot (X , self .beta_ , backend = self .backend )
145147
146- return mo .safe_sparse_dot (X [row_index , :], self .beta_ , backend = self .backend )
148+ return mo .safe_sparse_dot (
149+ X [row_index , :], self .beta_ , backend = self .backend
150+ )
147151
148152 def compute_XB2 (self , X , beta = None , row_index = None ):
149153 def f00 (X ):
150154 return mo .safe_sparse_dot (X , self .beta_ , backend = self .backend )
151155
152156 def f01 (X ):
153- return mo .safe_sparse_dot (X [row_index , :], self .beta_ , backend = self .backend )
157+ return mo .safe_sparse_dot (
158+ X [row_index , :], self .beta_ , backend = self .backend
159+ )
154160
155161 def f11 (X ):
156- return mo .safe_sparse_dot (X [row_index , :], beta , backend = self .backend )
162+ return mo .safe_sparse_dot (
163+ X [row_index , :], beta , backend = self .backend
164+ )
157165
158166 def f10 (X ):
159167 if self .backend != "cpu" :
160168 raise NotImplementedError (
161- "GLM.compute_XB is only implemented for backend='cpu'"
162- )
169+ "GLM.compute_XB is only implemented for backend='cpu'"
170+ )
163171 return mo .safe_sparse_dot (X , beta , backend = self .backend )
164172
165173 h_result = {"00" : f00 , "01" : f01 , "11" : f11 , "10" : f10 }
0 commit comments