Skip to content

Commit 9598598

Browse files
committed
revert
1 parent 1331f06 commit 9598598

File tree

3 files changed

+54
-54
lines changed

3 files changed

+54
-54
lines changed

src/nlp/batch/foreach.jl

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -122,12 +122,12 @@ batch_cons_nln(bnlp::ForEachBatchNLPModel, xs) =
122122
_batch_map(cons_nln, bnlp, xs)
123123
batch_jac(bnlp::ForEachBatchNLPModel, xs) =
124124
_batch_map(jac, bnlp, xs)
125-
batch_jac_lin(bnlp::ForEachBatchNLPModel) =
126-
_batch_map(jac_lin, bnlp)
125+
batch_jac_lin(bnlp::ForEachBatchNLPModel, xs) =
126+
_batch_map(jac_lin, bnlp, xs)
127127
batch_jac_nln(bnlp::ForEachBatchNLPModel, xs) =
128128
_batch_map(jac_nln, bnlp, xs)
129-
batch_jac_lin_coord(bnlp::ForEachBatchNLPModel) =
130-
_batch_map(jac_lin_coord, bnlp)
129+
batch_jac_lin_coord(bnlp::ForEachBatchNLPModel, xs) =
130+
_batch_map(jac_lin_coord, bnlp, xs)
131131
batch_jac_coord(bnlp::ForEachBatchNLPModel, xs) =
132132
_batch_map(jac_coord, bnlp, xs)
133133
batch_jac_nln_coord(bnlp::ForEachBatchNLPModel, xs) =
@@ -146,10 +146,10 @@ batch_jprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) =
146146
_batch_map(jprod_nln, bnlp, xs, vs)
147147
batch_jtprod_nln(bnlp::ForEachBatchNLPModel, xs, vs) =
148148
_batch_map(jtprod_nln, bnlp, xs, vs)
149-
batch_jprod_lin(bnlp::ForEachBatchNLPModel, vs) =
150-
_batch_map(jprod_lin, bnlp, vs)
151-
batch_jtprod_lin(bnlp::ForEachBatchNLPModel, vs) =
152-
_batch_map(jtprod_lin, bnlp, vs)
149+
batch_jprod_lin(bnlp::ForEachBatchNLPModel, xs, vs) =
150+
_batch_map(jprod_lin, bnlp, xs, vs)
151+
batch_jtprod_lin(bnlp::ForEachBatchNLPModel, xs, vs) =
152+
_batch_map(jtprod_lin, bnlp, xs, vs)
153153
batch_ghjvprod(bnlp::ForEachBatchNLPModel, xs, gs, vs) =
154154
_batch_map(ghjvprod, bnlp, xs, gs, vs)
155155

@@ -161,8 +161,8 @@ batch_jac_nln_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
161161
_batch_map!(jac_nln_structure!, bnlp, rowss, colss)
162162
batch_hess_structure!(bnlp::ForEachBatchNLPModel, rowss, colss) =
163163
_batch_map!(hess_structure!, bnlp, rowss, colss)
164-
batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, valss) =
165-
_batch_map!(jac_lin_coord!, bnlp, valss)
164+
batch_jac_lin_coord!(bnlp::ForEachBatchNLPModel, xs, valss) =
165+
_batch_map!(jac_lin_coord!, bnlp, xs, valss)
166166
batch_grad!(bnlp::ForEachBatchNLPModel, xs, gs) =
167167
_batch_map!(grad!, bnlp, xs, gs)
168168
batch_cons!(bnlp::ForEachBatchNLPModel, xs, cs) =
@@ -183,10 +183,10 @@ batch_jprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
183183
_batch_map!(jprod_nln!, bnlp, xs, vs, Jvs)
184184
batch_jtprod_nln!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
185185
_batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs)
186-
batch_jprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jvs) =
187-
_batch_map!(jprod_lin!, bnlp, vs, Jvs)
188-
batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, vs, Jtvs) =
189-
_batch_map!(jtprod_lin!, bnlp, vs, Jtvs)
186+
batch_jprod_lin!(bnlp::ForEachBatchNLPModel, xs, vs, Jvs) =
187+
_batch_map!(jprod_lin!, bnlp, xs, vs, Jvs)
188+
batch_jtprod_lin!(bnlp::ForEachBatchNLPModel, xs, vs, Jtvs) =
189+
_batch_map!(jtprod_lin!, bnlp, xs, vs, Jtvs)
190190
batch_ghjvprod!(bnlp::ForEachBatchNLPModel, xs, gs, vs, gHvs) =
191191
_batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs)
192192

@@ -246,15 +246,15 @@ batch_hess(bnlp::ForEachBatchNLPModel, xs, ys; obj_weights) =
246246
## operators
247247
batch_jac_op(bnlp::ForEachBatchNLPModel, xs) =
248248
_batch_map(jac_op, bnlp, xs)
249-
batch_jac_lin_op(bnlp::ForEachBatchNLPModel) =
250-
_batch_map(jac_lin_op, bnlp)
249+
batch_jac_lin_op(bnlp::ForEachBatchNLPModel, xs) =
250+
_batch_map(jac_lin_op, bnlp, xs)
251251
batch_jac_nln_op(bnlp::ForEachBatchNLPModel, xs) =
252252
_batch_map(jac_nln_op, bnlp, xs)
253253

254254
batch_jac_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
255255
_batch_map(jac_op!, bnlp, xs, Jvs, Jtvs)
256-
batch_jac_lin_op!(bnlp::ForEachBatchNLPModel, Jvs, Jtvs) =
257-
_batch_map(jac_lin_op!, bnlp, Jvs, Jtvs)
256+
batch_jac_lin_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
257+
_batch_map(jac_lin_op!, bnlp, xs, Jvs, Jtvs)
258258
batch_jac_nln_op!(bnlp::ForEachBatchNLPModel, xs, Jvs, Jtvs) =
259259
_batch_map(jac_nln_op!, bnlp, xs, Jvs, Jtvs)
260260

src/nlp/batch/inplace.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -132,12 +132,12 @@ batch_cons_nln(bnlp::InplaceBatchNLPModel, xs) =
132132
_batch_map(cons_nln, bnlp, xs)
133133
batch_jac(bnlp::InplaceBatchNLPModel, xs) =
134134
_batch_map(jac, bnlp, xs)
135-
batch_jac_lin(bnlp::InplaceBatchNLPModel) =
136-
_batch_map(jac_lin, bnlp)
135+
batch_jac_lin(bnlp::InplaceBatchNLPModel, xs) =
136+
_batch_map(jac_lin, bnlp, xs)
137137
batch_jac_nln(bnlp::InplaceBatchNLPModel, xs) =
138138
_batch_map(jac_nln, bnlp, xs)
139-
batch_jac_lin_coord(bnlp::InplaceBatchNLPModel) =
140-
_batch_map(jac_lin_coord, bnlp)
139+
batch_jac_lin_coord(bnlp::InplaceBatchNLPModel, xs) =
140+
_batch_map(jac_lin_coord, bnlp, xs)
141141
batch_jac_coord(bnlp::InplaceBatchNLPModel, xs) =
142142
_batch_map(jac_coord, bnlp, xs)
143143
batch_jac_nln_coord(bnlp::InplaceBatchNLPModel, xs) =
@@ -156,10 +156,10 @@ batch_jprod_nln(bnlp::InplaceBatchNLPModel, xs, vs) =
156156
_batch_map(jprod_nln, bnlp, xs, vs)
157157
batch_jtprod_nln(bnlp::InplaceBatchNLPModel, xs, vs) =
158158
_batch_map(jtprod_nln, bnlp, xs, vs)
159-
batch_jprod_lin(bnlp::InplaceBatchNLPModel, vs) =
160-
_batch_map(jprod_lin, bnlp, vs)
161-
batch_jtprod_lin(bnlp::InplaceBatchNLPModel, vs) =
162-
_batch_map(jtprod_lin, bnlp, vs)
159+
batch_jprod_lin(bnlp::InplaceBatchNLPModel, xs, vs) =
160+
_batch_map(jprod_lin, bnlp, xs, vs)
161+
batch_jtprod_lin(bnlp::InplaceBatchNLPModel, xs, vs) =
162+
_batch_map(jtprod_lin, bnlp, xs, vs)
163163
batch_ghjvprod(bnlp::InplaceBatchNLPModel, xs, gs, vs) =
164164
_batch_map(ghjvprod, bnlp, xs, gs, vs)
165165

@@ -171,8 +171,8 @@ batch_jac_nln_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) =
171171
_batch_map!(jac_nln_structure!, bnlp, rowss, colss)
172172
batch_hess_structure!(bnlp::InplaceBatchNLPModel, rowss, colss) =
173173
_batch_map!(hess_structure!, bnlp, rowss, colss)
174-
batch_jac_lin_coord!(bnlp::InplaceBatchNLPModel, valss) =
175-
_batch_map!(jac_lin_coord!, bnlp, valss)
174+
batch_jac_lin_coord!(bnlp::InplaceBatchNLPModel, xs, valss) =
175+
_batch_map!(jac_lin_coord!, bnlp, xs, valss)
176176
batch_grad!(bnlp::InplaceBatchNLPModel, xs, gs) =
177177
_batch_map!(grad!, bnlp, xs, gs)
178178
batch_cons!(bnlp::InplaceBatchNLPModel, xs, cs) =
@@ -193,10 +193,10 @@ batch_jprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) =
193193
_batch_map!(jprod_nln!, bnlp, xs, vs, Jvs)
194194
batch_jtprod_nln!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) =
195195
_batch_map!(jtprod_nln!, bnlp, xs, vs, Jtvs)
196-
batch_jprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jvs) =
197-
_batch_map!(jprod_lin!, bnlp, vs, Jvs)
198-
batch_jtprod_lin!(bnlp::InplaceBatchNLPModel, vs, Jtvs) =
199-
_batch_map!(jtprod_lin!, bnlp, vs, Jtvs)
196+
batch_jprod_lin!(bnlp::InplaceBatchNLPModel, xs, vs, Jvs) =
197+
_batch_map!(jprod_lin!, bnlp, xs, vs, Jvs)
198+
batch_jtprod_lin!(bnlp::InplaceBatchNLPModel, xs, vs, Jtvs) =
199+
_batch_map!(jtprod_lin!, bnlp, xs, vs, Jtvs)
200200
batch_ghjvprod!(bnlp::InplaceBatchNLPModel, xs, gs, vs, gHvs) =
201201
_batch_map!(ghjvprod!, bnlp, xs, gs, vs, gHvs)
202202

@@ -251,11 +251,11 @@ batch_hess(bnlp::InplaceBatchNLPModel, xs, ys; obj_weights) =
251251

252252
## operators
253253
batch_jac_op(bnlp::InplaceBatchNLPModel, xs) = _inplace_operator_error()
254-
batch_jac_lin_op(bnlp::InplaceBatchNLPModel) = _inplace_operator_error()
254+
batch_jac_lin_op(bnlp::InplaceBatchNLPModel, xs) = _inplace_operator_error()
255255
batch_jac_nln_op(bnlp::InplaceBatchNLPModel, xs) = _inplace_operator_error()
256256

257257
batch_jac_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = _inplace_operator_error()
258-
batch_jac_lin_op!(bnlp::InplaceBatchNLPModel, Jvs, Jtvs) = _inplace_operator_error()
258+
batch_jac_lin_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = _inplace_operator_error()
259259
batch_jac_nln_op!(bnlp::InplaceBatchNLPModel, xs, Jvs, Jtvs) = _inplace_operator_error()
260260

261261
## tuple functions

test/nlp/batch_api.jl

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,20 @@
127127
@test jac_coords manual_jac_coords
128128

129129
# Test batch_jac_lin
130-
batch_jac_lins = batch_jac_lin(bnlp)
131-
manual_jac_lins = [jac_lin(models[i]) for i = 1:n_models]
130+
batch_jac_lins = batch_jac_lin(bnlp, xs)
131+
manual_jac_lins = [jac_lin(models[i], xs[i]) for i = 1:n_models]
132132
@test batch_jac_lins manual_jac_lins
133133

134134
# Test batch_jac_lin_coord
135-
batch_jac_lin_coords = batch_jac_lin_coord(bnlp)
136-
manual_jac_lin_coords = [jac_lin_coord(models[i]) for i = 1:n_models]
135+
batch_jac_lin_coords = batch_jac_lin_coord(bnlp, xs)
136+
manual_jac_lin_coords = [jac_lin_coord(models[i], xs[i]) for i = 1:n_models]
137137
@test batch_jac_lin_coords manual_jac_lin_coords
138138

139139
# Test batch_jac_lin_coord!
140140
jac_lin_coords = [zeros(meta.lin_nnzj) for _ = 1:n_models]
141-
batch_jac_lin_coord!(bnlp, jac_lin_coords)
141+
batch_jac_lin_coord!(bnlp, xs, jac_lin_coords)
142142
manual_jac_lin_coords =
143-
[jac_lin_coord!(models[i], zeros(meta.lin_nnzj)) for i = 1:n_models]
143+
[jac_lin_coord!(models[i], xs[i], zeros(meta.lin_nnzj)) for i = 1:n_models]
144144
@test jac_lin_coords manual_jac_lin_coords
145145

146146
# Test batch_jac_nln
@@ -183,26 +183,26 @@
183183
@test jtprods manual_jtprods
184184

185185
# Test batch_jprod_lin
186-
batch_jprod_lins = batch_jprod_lin(bnlp, vs)
187-
manual_jprod_lins = [jprod_lin(models[i], vs[i]) for i = 1:n_models]
186+
batch_jprod_lins = batch_jprod_lin(bnlp, xs, vs)
187+
manual_jprod_lins = [jprod_lin(models[i], xs[i], vs[i]) for i = 1:n_models]
188188
@test batch_jprod_lins manual_jprod_lins
189189

190190
# Test batch_jprod_lin!
191191
jprod_lins = [zeros(meta.nlin) for _ = 1:n_models]
192-
batch_jprod_lin!(bnlp, vs, jprod_lins)
193-
manual_jprod_lins = [jprod_lin!(models[i], vs[i], zeros(meta.nlin)) for i = 1:n_models]
192+
batch_jprod_lin!(bnlp, xs, vs, jprod_lins)
193+
manual_jprod_lins = [jprod_lin!(models[i], xs[i], vs[i], zeros(meta.nlin)) for i = 1:n_models]
194194
@test jprod_lins manual_jprod_lins
195195

196196
# Test batch_jtprod_lin
197197
ws_lin = [ws[i][1:(meta.nlin)] for i = 1:n_models]
198-
batch_jtprod_lins = batch_jtprod_lin(bnlp, ws_lin)
199-
manual_jtprod_lins = [jtprod_lin(models[i], ws_lin[i]) for i = 1:n_models]
198+
batch_jtprod_lins = batch_jtprod_lin(bnlp, xs, ws_lin)
199+
manual_jtprod_lins = [jtprod_lin(models[i], xs[i], ws_lin[i]) for i = 1:n_models]
200200
@test batch_jtprod_lins manual_jtprod_lins
201201

202202
# Test batch_jtprod_lin!
203203
jtprod_lins = [zeros(n) for _ = 1:n_models]
204-
batch_jtprod_lin!(bnlp, ws_lin, jtprod_lins)
205-
manual_jtprod_lins = [jtprod_lin!(models[i], ws_lin[i], zeros(n)) for i = 1:n_models]
204+
batch_jtprod_lin!(bnlp, xs, ws_lin, jtprod_lins)
205+
manual_jtprod_lins = [jtprod_lin!(models[i], xs[i], ws_lin[i], zeros(n)) for i = 1:n_models]
206206
@test jtprod_lins manual_jtprod_lins
207207

208208
# Test batch_jprod_nln
@@ -432,8 +432,8 @@
432432
end
433433

434434
# Test batch_jac_lin_op
435-
batch_jac_lin_ops = batch_jac_lin_op(bnlp)
436-
manual_jac_lin_ops = [jac_lin_op(models[i]) for i = 1:n_models]
435+
batch_jac_lin_ops = batch_jac_lin_op(bnlp, xs)
436+
manual_jac_lin_ops = [jac_lin_op(models[i], xs[i]) for i = 1:n_models]
437437
ws_lin_vec = ws[1][1:(meta.nlin)]
438438
for i = 1:n_models
439439
@test batch_jac_lin_ops[i] * vs[i] manual_jac_lin_ops[i] * vs[i]
@@ -443,9 +443,9 @@
443443
# Test batch_jac_lin_op!
444444
jvs_lin = [zeros(meta.nlin) for _ = 1:n_models]
445445
jtvs_lin = [zeros(n) for _ = 1:n_models]
446-
batch_jac_lin_ops = batch_jac_lin_op!(bnlp, jvs_lin, jtvs_lin)
446+
batch_jac_lin_ops = batch_jac_lin_op!(bnlp, xs, jvs_lin, jtvs_lin)
447447
manual_jac_lin_ops =
448-
[jac_lin_op!(models[i], zeros(meta.nlin), zeros(n)) for i = 1:n_models]
448+
[jac_lin_op!(models[i], xs[i], zeros(meta.nlin), zeros(n)) for i = 1:n_models]
449449
for i = 1:n_models
450450
@test batch_jac_lin_ops[i] * vs[i] manual_jac_lin_ops[i] * vs[i]
451451
@test batch_jac_lin_ops[i]' * ws_lin_vec manual_jac_lin_ops[i]' * ws_lin_vec
@@ -474,8 +474,8 @@
474474
@test_throws ErrorException batch_jac_op(bnlp, xs)
475475
@test_throws ErrorException batch_jac_op!(bnlp, xs, [zeros(m) for _ = 1:n_models],
476476
[zeros(n) for _ = 1:n_models])
477-
@test_throws ErrorException batch_jac_lin_op(bnlp)
478-
@test_throws ErrorException batch_jac_lin_op!(bnlp,
477+
@test_throws ErrorException batch_jac_lin_op(bnlp, xs)
478+
@test_throws ErrorException batch_jac_lin_op!(bnlp, xs,
479479
[zeros(meta.nlin) for _ = 1:n_models],
480480
[zeros(n) for _ = 1:n_models])
481481
@test_throws ErrorException batch_jac_nln_op(bnlp, xs)

0 commit comments

Comments
 (0)