@@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
114114 if ctx .cpu is not None :
115115 ctx .cpu = max (1 , int (ctx .cpu / 2 ))
116116
117- zhats = torch .Tensor ( nBatch , ctx .nz ).type_as (Q )
118- lams = torch .Tensor ( nBatch , ctx .neq ).type_as (Q )
119- nus = torch .Tensor ( nBatch , ctx .nineq ).type_as (Q )
117+ zhats = torch .empty (( nBatch , ctx .nz ) ).type_as (Q )
118+ lams = torch .empty (( nBatch , ctx .neq ) ).type_as (Q )
119+ nus = torch .empty (( nBatch , ctx .nineq ) ).type_as (Q )
120120
121121 for i in range (nBatch ):
122122 qp = ctx .vector_of_qps .init_qp_in_place (ctx .nz , ctx .neq , ctx .nineq )
@@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
163163 ctx .vector_of_qps .get (i ).solve ()
164164
165165 for i in range (nBatch ):
166- zhats [i ] = torch .Tensor (ctx .vector_of_qps .get (i ).results .x )
167- lams [i ] = torch .Tensor (ctx .vector_of_qps .get (i ).results .y )
168- nus [i ] = torch .Tensor (ctx .vector_of_qps .get (i ).results .z )
166+ zhats [i ] = torch .tensor (ctx .vector_of_qps .get (i ).results .x )
167+ lams [i ] = torch .tensor (ctx .vector_of_qps .get (i ).results .y )
168+ nus [i ] = torch .tensor (ctx .vector_of_qps .get (i ).results .z )
169169
170170 return zhats , lams , nus
171171
172172 @staticmethod
173173 def backward (ctx , dl_dzhat , dl_dlams , dl_dnus ):
174+ device = dl_dzhat .device
174175 nBatch , dim , neq , nineq = ctx .nBatch , ctx .nz , ctx .neq , ctx .nineq
175- dQs = torch .Tensor (nBatch , ctx .nz , ctx .nz )
176- dps = torch .Tensor (nBatch , ctx .nz )
177- dGs = torch .Tensor (nBatch , ctx .nineq , ctx .nz )
178- dus = torch .Tensor (nBatch , ctx .nineq )
179- dls = torch .Tensor (nBatch , ctx .nineq )
180- dAs = torch .Tensor (nBatch , ctx .neq , ctx .nz )
181- dbs = torch .Tensor (nBatch , ctx .neq )
176+ dQs = torch .empty (nBatch , ctx .nz , ctx .nz , device = device )
177+ dps = torch .empty (nBatch , ctx .nz , device = device )
178+ dGs = torch .empty (nBatch , ctx .nineq , ctx .nz , device = device )
179+ dus = torch .empty (nBatch , ctx .nineq , device = device )
180+ dls = torch .empty (nBatch , ctx .nineq , device = device )
181+ dAs = torch .empty (nBatch , ctx .neq , ctx .nz , device = device )
182+ dbs = torch .empty (nBatch , ctx .neq , device = device )
182183
183184 ctx .cpu = os .cpu_count ()
184185 if ctx .cpu is not None :
@@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
211212 else :
212213 for i in range (nBatch ):
213214 rhs = np .zeros (n_tot )
214- rhs [:dim ] = dl_dzhat [i ]
215+ rhs [:dim ] = dl_dzhat [i ]. cpu ()
215216 if dl_dlams != None :
216- rhs [dim : dim + neq ] = dl_dlams [i ]
217+ rhs [dim : dim + neq ] = dl_dlams [i ]. cpu ()
217218 if dl_dnus != None :
218- rhs [dim + neq :] = dl_dnus [i ]
219+ rhs [dim + neq :] = dl_dnus [i ]. cpu ()
219220 qpi = ctx .vector_of_qps .get (i )
220221 proxsuite .proxqp .dense .compute_backward (
221222 qp = qpi ,
@@ -226,25 +227,25 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
226227 )
227228
228229 for i in range (nBatch ):
229- dQs [i ] = torch .Tensor (
230+ dQs [i ] = torch .tensor (
230231 ctx .vector_of_qps .get (i ).model .backward_data .dL_dH
231232 )
232- dps [i ] = torch .Tensor (
233+ dps [i ] = torch .tensor (
233234 ctx .vector_of_qps .get (i ).model .backward_data .dL_dg
234235 )
235- dGs [i ] = torch .Tensor (
236+ dGs [i ] = torch .tensor (
236237 ctx .vector_of_qps .get (i ).model .backward_data .dL_dC
237238 )
238- dus [i ] = torch .Tensor (
239+ dus [i ] = torch .tensor (
239240 ctx .vector_of_qps .get (i ).model .backward_data .dL_du
240241 )
241- dls [i ] = torch .Tensor (
242+ dls [i ] = torch .tensor (
242243 ctx .vector_of_qps .get (i ).model .backward_data .dL_dl
243244 )
244- dAs [i ] = torch .Tensor (
245+ dAs [i ] = torch .tensor (
245246 ctx .vector_of_qps .get (i ).model .backward_data .dL_dA
246247 )
247- dbs [i ] = torch .Tensor (
248+ dbs [i ] = torch .tensor (
248249 ctx .vector_of_qps .get (i ).model .backward_data .dL_db
249250 )
250251
0 commit comments