Skip to content

Commit dde7c0f

Browse files
authored
Merge pull request #297 from Simple-Robotics/fix-deal-cpu-gpu-qpfunction-backward
Handles CPU/GPU transfer in QPFunctionFn's backward function
2 parents 21af3fd + 7661158 commit dde7c0f

File tree

1 file changed

+24
-23
lines changed

1 file changed

+24
-23
lines changed

bindings/python/proxsuite/torch/qplayer.py

Lines changed: 24 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,9 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
114114
if ctx.cpu is not None:
115115
ctx.cpu = max(1, int(ctx.cpu / 2))
116116

117-
zhats = torch.Tensor(nBatch, ctx.nz).type_as(Q)
118-
lams = torch.Tensor(nBatch, ctx.neq).type_as(Q)
119-
nus = torch.Tensor(nBatch, ctx.nineq).type_as(Q)
117+
zhats = torch.empty((nBatch, ctx.nz)).type_as(Q)
118+
lams = torch.empty((nBatch, ctx.neq)).type_as(Q)
119+
nus = torch.empty((nBatch, ctx.nineq)).type_as(Q)
120120

121121
for i in range(nBatch):
122122
qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
@@ -163,22 +163,23 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
163163
ctx.vector_of_qps.get(i).solve()
164164

165165
for i in range(nBatch):
166-
zhats[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.x)
167-
lams[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.y)
168-
nus[i] = torch.Tensor(ctx.vector_of_qps.get(i).results.z)
166+
zhats[i] = torch.tensor(ctx.vector_of_qps.get(i).results.x)
167+
lams[i] = torch.tensor(ctx.vector_of_qps.get(i).results.y)
168+
nus[i] = torch.tensor(ctx.vector_of_qps.get(i).results.z)
169169

170170
return zhats, lams, nus
171171

172172
@staticmethod
173173
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
174+
device = dl_dzhat.device
174175
nBatch, dim, neq, nineq = ctx.nBatch, ctx.nz, ctx.neq, ctx.nineq
175-
dQs = torch.Tensor(nBatch, ctx.nz, ctx.nz)
176-
dps = torch.Tensor(nBatch, ctx.nz)
177-
dGs = torch.Tensor(nBatch, ctx.nineq, ctx.nz)
178-
dus = torch.Tensor(nBatch, ctx.nineq)
179-
dls = torch.Tensor(nBatch, ctx.nineq)
180-
dAs = torch.Tensor(nBatch, ctx.neq, ctx.nz)
181-
dbs = torch.Tensor(nBatch, ctx.neq)
176+
dQs = torch.empty(nBatch, ctx.nz, ctx.nz, device=device)
177+
dps = torch.empty(nBatch, ctx.nz, device=device)
178+
dGs = torch.empty(nBatch, ctx.nineq, ctx.nz, device=device)
179+
dus = torch.empty(nBatch, ctx.nineq, device=device)
180+
dls = torch.empty(nBatch, ctx.nineq, device=device)
181+
dAs = torch.empty(nBatch, ctx.neq, ctx.nz, device=device)
182+
dbs = torch.empty(nBatch, ctx.neq, device=device)
182183

183184
ctx.cpu = os.cpu_count()
184185
if ctx.cpu is not None:
@@ -211,11 +212,11 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
211212
else:
212213
for i in range(nBatch):
213214
rhs = np.zeros(n_tot)
214-
rhs[:dim] = dl_dzhat[i]
215+
rhs[:dim] = dl_dzhat[i].cpu()
215216
if dl_dlams != None:
216-
rhs[dim : dim + neq] = dl_dlams[i]
217+
rhs[dim : dim + neq] = dl_dlams[i].cpu()
217218
if dl_dnus != None:
218-
rhs[dim + neq :] = dl_dnus[i]
219+
rhs[dim + neq :] = dl_dnus[i].cpu()
219220
qpi = ctx.vector_of_qps.get(i)
220221
proxsuite.proxqp.dense.compute_backward(
221222
qp=qpi,
@@ -226,25 +227,25 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
226227
)
227228

228229
for i in range(nBatch):
229-
dQs[i] = torch.Tensor(
230+
dQs[i] = torch.tensor(
230231
ctx.vector_of_qps.get(i).model.backward_data.dL_dH
231232
)
232-
dps[i] = torch.Tensor(
233+
dps[i] = torch.tensor(
233234
ctx.vector_of_qps.get(i).model.backward_data.dL_dg
234235
)
235-
dGs[i] = torch.Tensor(
236+
dGs[i] = torch.tensor(
236237
ctx.vector_of_qps.get(i).model.backward_data.dL_dC
237238
)
238-
dus[i] = torch.Tensor(
239+
dus[i] = torch.tensor(
239240
ctx.vector_of_qps.get(i).model.backward_data.dL_du
240241
)
241-
dls[i] = torch.Tensor(
242+
dls[i] = torch.tensor(
242243
ctx.vector_of_qps.get(i).model.backward_data.dL_dl
243244
)
244-
dAs[i] = torch.Tensor(
245+
dAs[i] = torch.tensor(
245246
ctx.vector_of_qps.get(i).model.backward_data.dL_dA
246247
)
247-
dbs[i] = torch.Tensor(
248+
dbs[i] = torch.tensor(
248249
ctx.vector_of_qps.get(i).model.backward_data.dL_db
249250
)
250251

0 commit comments

Comments
 (0)