Skip to content

Commit 52d0095

Browse files
authored
Merge pull request #350 from jcarpent/devel
Sync submodule CMake
2 parents 6d708e6 + c66e666 commit 52d0095

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

bindings/python/proxsuite/torch/qplayer.py

Lines changed: 46 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
256256
class QPFunctionFn_infeas(Function):
257257
@staticmethod
258258
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
259-
n_in, nz = G_.size() # true double-sided inequality size
259+
n_in, nz = G_.size() # true double-sided inequality size
260260
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
261261

262262
Q, _ = expandParam(Q_, nBatch, 3)
@@ -277,7 +277,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
277277

278278
zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
279279
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
280-
nus_sol = torch.empty((nBatch, n_in), dtype=Q.dtype) # double-sided inequality multiplier
280+
nus_sol = torch.empty(
281+
(nBatch, n_in), dtype=Q.dtype
282+
) # double-sided inequality multiplier
281283
lams = (
282284
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
283285
if ctx.neq > 0
@@ -289,7 +291,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
289291
else torch.empty()
290292
)
291293
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
292-
s_i = torch.empty((nBatch, n_in), dtype=Q.dtype) # this one is of size the one of the original n_in
294+
s_i = torch.empty(
295+
(nBatch, n_in), dtype=Q.dtype
296+
) # this one is of size the one of the original n_in
293297

294298
vector_of_qps = proxsuite.proxqp.dense.BatchQP()
295299

@@ -342,17 +346,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
342346

343347
for i in range(nBatch):
344348
zhats[i] = torch.tensor(vector_of_qps.get(i).results.x)
345-
if nineq>0:
349+
if nineq > 0:
346350
# we re-convert the solution to a double sided inequality QP
347351
slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
348-
nus_sol[i] = torch.Tensor(-vector_of_qps.get(i).results.z[:n_in]+vector_of_qps.get(i).results.z[n_in:]) # de-projecting this one may provoke loss of information when using inexact solution
352+
nus_sol[i] = torch.Tensor(
353+
-vector_of_qps.get(i).results.z[:n_in]
354+
+ vector_of_qps.get(i).results.z[n_in:]
355+
) # de-projecting this one may provoke loss of information when using inexact solution
349356
nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
350357
slacks[i] = slack.clone().detach()
351-
s_i[i] = torch.tensor(-vector_of_qps.get(i).results.si[:n_in]+vector_of_qps.get(i).results.si[n_in:])
358+
s_i[i] = torch.tensor(
359+
-vector_of_qps.get(i).results.si[:n_in]
360+
+ vector_of_qps.get(i).results.si[n_in:]
361+
)
352362
if neq > 0:
353363
lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
354364
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
355-
365+
356366
ctx.lams = lams
357367
ctx.nus = nus
358368
ctx.slacks = slacks
@@ -377,7 +387,7 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
377387

378388
neq, nineq = ctx.neq, ctx.nineq
379389
# true size
380-
n_in_sol = int(nineq/2)
390+
n_in_sol = int(nineq / 2)
381391
dx = torch.zeros((nBatch, Q.shape[1]))
382392
dnu = None
383393
b_5 = None
@@ -464,26 +474,34 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
464474
rhs = np.zeros(kkt.shape[0])
465475
rhs[:dim] = -dl_dzhat[i]
466476
if dl_dlams != None:
467-
if n_eq!= 0:
477+
if n_eq != 0:
468478
rhs[dim : dim + n_eq] = -dl_dlams[i]
469-
active_set = None
470-
if n_in!=0:
471-
active_set = -z_i[:n_in_sol]+z_i[n_in_sol:] >= 0
479+
active_set = None
480+
if n_in != 0:
481+
active_set = -z_i[:n_in_sol] + z_i[n_in_sol:] >= 0
472482
if dl_dnus != None:
473-
if n_in !=0:
483+
if n_in != 0:
474484
# we must convert dl_dnus to a uni sided version
475-
# to do so we reconstitute the active set
476-
rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[i][~active_set]
477-
rhs[dim + n_eq + n_in_sol: dim + n_eq + n_in][active_set] = -dl_dnus[i][active_set]
485+
# to do so we reconstitute the active set
486+
rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[
487+
i
488+
][~active_set]
489+
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = (
490+
-dl_dnus[i][active_set]
491+
)
478492
if dl_ds_e != None:
479493
if dl_ds_e.shape[0] != 0:
480494
rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
481495
if dl_ds_i != None:
482496
if dl_ds_i.shape[0] != 0:
483497
# we must convert dl_dnus to a uni sided version
484-
# to do so we reconstitute the active set
485-
rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][~active_set] = dl_ds_i[i][~active_set]
486-
rhs[dim + 2 * n_eq + n_in + n_in_sol:][active_set] = -dl_ds_i[i][active_set]
498+
# to do so we reconstitute the active set
499+
rhs[dim + 2 * n_eq + n_in : dim + 2 * n_eq + n_in + n_in_sol][
500+
~active_set
501+
] = dl_ds_i[i][~active_set]
502+
rhs[dim + 2 * n_eq + n_in + n_in_sol :][active_set] = -dl_ds_i[
503+
i
504+
][active_set]
487505

488506
l = np.zeros(0)
489507
u = np.zeros(0)
@@ -580,7 +598,15 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
580598
if p_e:
581599
dps = dps.mean(0)
582600

583-
grads = (dQs, dps, dAs, dbs, dGs[n_in_sol:, :], -dhs[:n_in_sol], dhs[n_in_sol:])
601+
grads = (
602+
dQs,
603+
dps,
604+
dAs,
605+
dbs,
606+
dGs[n_in_sol:, :],
607+
-dhs[:n_in_sol],
608+
dhs[n_in_sol:],
609+
)
584610

585611
return grads
586612

0 commit comments

Comments
 (0)