@@ -256,7 +256,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
256256 class QPFunctionFn_infeas (Function ):
257257 @staticmethod
258258 def forward (ctx , Q_ , p_ , A_ , b_ , G_ , l_ , u_ ):
259- n_in , nz = G_ .size () # true double-sided inequality size
259+ n_in , nz = G_ .size () # true double-sided inequality size
260260 nBatch = extract_nBatch (Q_ , p_ , A_ , b_ , G_ , l_ , u_ )
261261
262262 Q , _ = expandParam (Q_ , nBatch , 3 )
@@ -277,7 +277,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
277277
278278 zhats = torch .empty ((nBatch , ctx .nz ), dtype = Q .dtype )
279279 nus = torch .empty ((nBatch , ctx .nineq ), dtype = Q .dtype )
280- nus_sol = torch .empty ((nBatch , n_in ), dtype = Q .dtype ) # double-sided inequality multiplier
280+ nus_sol = torch .empty (
281+ (nBatch , n_in ), dtype = Q .dtype
282+ ) # double-sided inequality multiplier
281283 lams = (
282284 torch .empty (nBatch , ctx .neq , dtype = Q .dtype )
283285 if ctx .neq > 0
@@ -289,7 +291,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
289291 else torch .empty ()
290292 )
291293 slacks = torch .empty ((nBatch , ctx .nineq ), dtype = Q .dtype )
292- s_i = torch .empty ((nBatch , n_in ), dtype = Q .dtype ) # this one is of size the one of the original n_in
294+ s_i = torch .empty (
295+ (nBatch , n_in ), dtype = Q .dtype
296+ ) # this one is of size the one of the original n_in
293297
294298 vector_of_qps = proxsuite .proxqp .dense .BatchQP ()
295299
@@ -342,17 +346,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
342346
343347 for i in range (nBatch ):
344348 zhats [i ] = torch .tensor (vector_of_qps .get (i ).results .x )
345- if nineq > 0 :
349+ if nineq > 0 :
346350 # we re-convert the solution to a double sided inequality QP
347351 slack = - h [i ] + G [i ] @ vector_of_qps .get (i ).results .x
348- nus_sol [i ] = torch .Tensor (- vector_of_qps .get (i ).results .z [:n_in ]+ vector_of_qps .get (i ).results .z [n_in :]) # de-projecting this one may provoke loss of information when using inexact solution
352+ nus_sol [i ] = torch .Tensor (
353+ - vector_of_qps .get (i ).results .z [:n_in ]
354+ + vector_of_qps .get (i ).results .z [n_in :]
355+ ) # de-projecting this one may provoke loss of information when using inexact solution
349356 nus [i ] = torch .tensor (vector_of_qps .get (i ).results .z )
350357 slacks [i ] = slack .clone ().detach ()
351- s_i [i ] = torch .tensor (- vector_of_qps .get (i ).results .si [:n_in ]+ vector_of_qps .get (i ).results .si [n_in :])
358+ s_i [i ] = torch .tensor (
359+ - vector_of_qps .get (i ).results .si [:n_in ]
360+ + vector_of_qps .get (i ).results .si [n_in :]
361+ )
352362 if neq > 0 :
353363 lams [i ] = torch .tensor (vector_of_qps .get (i ).results .y )
354364 s_e [i ] = torch .tensor (vector_of_qps .get (i ).results .se )
355-
365+
356366 ctx .lams = lams
357367 ctx .nus = nus
358368 ctx .slacks = slacks
@@ -377,7 +387,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
377387
378388 neq , nineq = ctx .neq , ctx .nineq
379389 # true size
380- n_in_sol = int (nineq / 2 )
390+ n_in_sol = int (nineq / 2 )
381391 dx = torch .zeros ((nBatch , Q .shape [1 ]))
382392 dnu = None
383393 b_5 = None
@@ -464,26 +474,34 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
464474 rhs = np .zeros (kkt .shape [0 ])
465475 rhs [:dim ] = - dl_dzhat [i ]
466476 if dl_dlams != None :
467- if n_eq != 0 :
477+ if n_eq != 0 :
468478 rhs [dim : dim + n_eq ] = - dl_dlams [i ]
469- active_set = None
470- if n_in != 0 :
471- active_set = - z_i [:n_in_sol ]+ z_i [n_in_sol :] >= 0
479+ active_set = None
480+ if n_in != 0 :
481+ active_set = - z_i [:n_in_sol ] + z_i [n_in_sol :] >= 0
472482 if dl_dnus != None :
473- if n_in != 0 :
483+ if n_in != 0 :
474484 # we must convert dl_dnus to a uni sided version
475- # to do so we reconstitute the active set
476- rhs [dim + n_eq : dim + n_eq + n_in_sol ][~ active_set ] = dl_dnus [i ][~ active_set ]
477- rhs [dim + n_eq + n_in_sol : dim + n_eq + n_in ][active_set ] = - dl_dnus [i ][active_set ]
485+ # to do so we reconstitute the active set
486+ rhs [dim + n_eq : dim + n_eq + n_in_sol ][~ active_set ] = dl_dnus [
487+ i
488+ ][~ active_set ]
489+ rhs [dim + n_eq + n_in_sol : dim + n_eq + n_in ][active_set ] = (
490+ - dl_dnus [i ][active_set ]
491+ )
478492 if dl_ds_e != None :
479493 if dl_ds_e .shape [0 ] != 0 :
480494 rhs [dim + n_eq + n_in : dim + 2 * n_eq + n_in ] = - dl_ds_e [i ]
481495 if dl_ds_i != None :
482496 if dl_ds_i .shape [0 ] != 0 :
483497 # we must convert dl_dnus to a uni sided version
484- # to do so we reconstitute the active set
485- rhs [dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol ][~ active_set ] = dl_ds_i [i ][~ active_set ]
486- rhs [dim + 2 * n_eq + n_in + n_in_sol :][active_set ] = - dl_ds_i [i ][active_set ]
498+ # to do so we reconstitute the active set
499+ rhs [dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol ][
500+ ~ active_set
501+ ] = dl_ds_i [i ][~ active_set ]
502+ rhs [dim + 2 * n_eq + n_in + n_in_sol :][active_set ] = - dl_ds_i [
503+ i
504+ ][active_set ]
487505
488506 l = np .zeros (0 )
489507 u = np .zeros (0 )
@@ -580,7 +598,15 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
580598 if p_e :
581599 dps = dps .mean (0 )
582600
583- grads = (dQs , dps , dAs , dbs , dGs [n_in_sol :, :], - dhs [:n_in_sol ], dhs [n_in_sol :])
601+ grads = (
602+ dQs ,
603+ dps ,
604+ dAs ,
605+ dbs ,
606+ dGs [n_in_sol :, :],
607+ - dhs [:n_in_sol ],
608+ dhs [n_in_sol :],
609+ )
584610
585611 return grads
586612
0 commit comments