Skip to content

Commit 40fbb0c

Browse files
committed
simplify inputs/outputs
1 parent 7864819 commit 40fbb0c

File tree

2 files changed

+12
-24
lines changed

2 files changed

+12
-24
lines changed

src/nlp/batch/foreach.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,11 @@ function _batch_map!(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F
2929
n = bnlp.batch_size
3030
length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.")
3131
@lencheck_tup n xs
32-
outputs = xs[end]
33-
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
34-
@lencheck n outputs
3532
for i = 1:n
36-
args_i = (x[i] for x in inputs)
37-
f(bnlp[i], args_i..., outputs[i])
33+
args_i = (x[i] for x in xs)
34+
f(bnlp[i], args_i...)
3835
end
39-
return outputs
36+
return xs[end]
4037
end
4138

4239
function _batch_map_weight(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
@@ -57,14 +54,11 @@ function _batch_map_weight!(f::F, bnlp::ForEachBatchNLPModel, obj_weights, xs::V
5754
length(xs) == 0 && error("Cannot call _batch_map_weight! without providing arguments.")
5855
@lencheck_tup n xs
5956
@lencheck n obj_weights
60-
outputs = xs[end]
61-
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
62-
@lencheck n outputs
6357
for i = 1:n
64-
args_i = (x[i] for x in inputs)
65-
f(bnlp[i], args_i..., outputs[i]; obj_weight = obj_weights[i])
58+
args_i = (x[i] for x in xs)
59+
f(bnlp[i], args_i...; obj_weight = obj_weights[i])
6660
end
67-
return outputs
61+
return xs[end]
6862
end
6963

7064
function _batch_map_tuple(f::F, bnlp::ForEachBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}

src/nlp/batch/inplace.jl

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,12 @@ function _batch_map!(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F
3535
n = bnlp.batch_size
3636
length(xs) == 0 && error("Cannot call _batch_map! without providing arguments.")
3737
@lencheck_tup n xs
38-
outputs = xs[end]
39-
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
40-
@lencheck n outputs
4138
for i = 1:n
42-
args_i = (x[i] for x in inputs)
39+
args_i = (x[i] for x in xs)
4340
bnlp.updates[i](bnlp.base_model) # call update function
44-
f(bnlp.base_model, args_i..., outputs[i])
41+
f(bnlp.base_model, args_i...)
4542
end
46-
return outputs
43+
return xs[end]
4744
end
4845

4946
function _batch_map_weight(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::Vararg{T,N}) where {F,T,N}
@@ -65,15 +62,12 @@ function _batch_map_weight!(f::F, bnlp::InplaceBatchNLPModel, obj_weights, xs::V
6562
length(xs) == 0 && error("_batch_map_weight! with zero args")
6663
@lencheck_tup n xs
6764
@lencheck n obj_weights
68-
outputs = xs[end]
69-
inputs = length(xs) == 1 ? () : Base.ntuple(i -> xs[i], length(xs) - 1)
70-
@lencheck n outputs
7165
for i = 1:n
72-
args_i = (x[i] for x in inputs)
66+
args_i = (x[i] for x in xs)
7367
bnlp.updates[i](bnlp.base_model) # call update function
74-
f(bnlp.base_model, args_i..., outputs[i]; obj_weight = obj_weights[i])
68+
f(bnlp.base_model, args_i...; obj_weight = obj_weights[i])
7569
end
76-
return outputs
70+
return xs[end]
7771
end
7872

7973
function _batch_map_tuple(f::F, bnlp::InplaceBatchNLPModel, xs::Vararg{T,N}) where {F,T,N}

0 commit comments

Comments
 (0)