Skip to content

Commit 373d5ad

Browse files
committed
[Nonliear] fix performance of evaluating univariate operators
1 parent d389ba1 commit 373d5ad

File tree

5 files changed

+198
-24
lines changed

5 files changed

+198
-24
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,7 +376,7 @@ function _forward_eval_ϵ(
376376
@inbounds child_idx = children_arr[ex.adj.colptr[k]]
377377
f′′ = Nonlinear.eval_univariate_hessian(
378378
user_operators,
379-
user_operators.univariate_operators[node.index],
379+
node.index,
380380
ex.forward_storage[child_idx],
381381
)
382382
partials_storage_ϵ[child_idx] = f′′ * storage_ϵ[child_idx]

src/Nonlinear/ReverseAD/reverse_mode.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -248,16 +248,13 @@ function _forward_eval(
248248
end
249249
elseif node.type == Nonlinear.NODE_CALL_UNIVARIATE
250250
child_idx = children_arr[f.adj.colptr[k]]
251-
f.forward_storage[k] = Nonlinear.eval_univariate_function(
251+
ret_f, ret_f′ = Nonlinear.eval_univariate_function_and_gradient(
252252
operators,
253-
operators.univariate_operators[node.index],
254-
f.forward_storage[child_idx],
255-
)
256-
f.partials_storage[child_idx] = Nonlinear.eval_univariate_gradient(
257-
operators,
258-
operators.univariate_operators[node.index],
253+
node.index,
259254
f.forward_storage[child_idx],
260255
)
256+
f.forward_storage[k] = ret_f
257+
f.partials_storage[child_idx] = ret_f′
261258
elseif node.type == Nonlinear.NODE_COMPARISON
262259
children_idx = SparseArrays.nzrange(f.adj, k)
263260
result = true

src/Nonlinear/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -377,7 +377,7 @@ function evaluate(
377377
child_idx = children_arr[adj.colptr[k]]
378378
storage[k] = eval_univariate_function(
379379
model.operators,
380-
model.operators.univariate_operators[node.index],
380+
node.index,
381381
storage[child_idx],
382382
)
383383
elseif node.type == NODE_COMPARISON

src/Nonlinear/operators.jl

Lines changed: 146 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -517,19 +517,47 @@ end
517517
"""
518518
eval_univariate_function(
519519
registry::OperatorRegistry,
520-
op::Symbol,
520+
op::Union{Symbol,Integer},
521521
x::T,
522522
) where {T}
523523
524524
Evaluate the operator `op(x)::T`, where `op` is a univariate function in
525525
`registry`.
526+
527+
If `op isa Integer`, then `op` is the index in
528+
`registry.univariate_operators[op]`.
529+
530+
## Example
531+
532+
```jldoctest
533+
julia> import MathOptInterface as MOI
534+
535+
julia> r = MOI.Nonlinear.OperatorRegistry();
536+
537+
julia> MOI.Nonlinear.eval_univariate_function(r, :abs, -1.2)
538+
1.2
539+
540+
julia> r.univariate_operators[3]
541+
:abs
542+
543+
julia> MOI.Nonlinear.eval_univariate_function(r, 3, -1.2)
544+
1.2
545+
```
526546
"""
527547
function eval_univariate_function(
528548
registry::OperatorRegistry,
529549
op::Symbol,
530550
x::T,
531551
) where {T}
532552
id = registry.univariate_operator_to_id[op]
553+
return eval_univariate_function(registry, id, x)
554+
end
555+
556+
function eval_univariate_function(
557+
registry::OperatorRegistry,
558+
id::Integer,
559+
x::T,
560+
) where {T}
533561
if id <= registry.univariate_user_operator_start
534562
f, _ = _eval_univariate(id, x)
535563
return f::T
@@ -544,19 +572,47 @@ end
544572
"""
545573
eval_univariate_gradient(
546574
registry::OperatorRegistry,
547-
op::Symbol,
575+
op::Union{Symbol,Integer},
548576
x::T,
549577
) where {T}
550578
551579
Evaluate the first-derivative of the operator `op(x)::T`, where `op` is a
552580
univariate function in `registry`.
581+
582+
If `op isa Integer`, then `op` is the index in
583+
`registry.univariate_operators[op]`.
584+
585+
## Example
586+
587+
```jldoctest
588+
julia> import MathOptInterface as MOI
589+
590+
julia> r = MOI.Nonlinear.OperatorRegistry();
591+
592+
julia> MOI.Nonlinear.eval_univariate_gradient(r, :abs, -1.2)
593+
-1.0
594+
595+
julia> r.univariate_operators[3]
596+
:abs
597+
598+
julia> MOI.Nonlinear.eval_univariate_gradient(r, 3, -1.2)
599+
-1.0
600+
```
553601
"""
554602
function eval_univariate_gradient(
555603
registry::OperatorRegistry,
556604
op::Symbol,
557605
x::T,
558606
) where {T}
559607
id = registry.univariate_operator_to_id[op]
608+
return eval_univariate_gradient(registry, id, x)
609+
end
610+
611+
function eval_univariate_gradient(
612+
registry::OperatorRegistry,
613+
id::Integer,
614+
x::T,
615+
) where {T}
560616
if id <= registry.univariate_user_operator_start
561617
_, f′ = _eval_univariate(id, x)
562618
return f′::T
@@ -568,22 +624,109 @@ function eval_univariate_gradient(
568624
return ret::T
569625
end
570626

627+
"""
628+
eval_univariate_function_and_gradient(
629+
registry::OperatorRegistry,
630+
op::Union{Symbol,Integer},
631+
x::T,
632+
)::Tuple{T,T} where {T}
633+
634+
Evaluate the function and first-derivative of the operator `op(x)::T`, where
635+
`op` is a univariate function in `registry`.
636+
637+
If `op isa Integer`, then `op` is the index in
638+
`registry.univariate_operators[op]`.
639+
640+
## Example
641+
642+
```jldoctest
643+
julia> import MathOptInterface as MOI
644+
645+
julia> r = MOI.Nonlinear.OperatorRegistry();
646+
647+
julia> MOI.Nonlinear.eval_univariate_function_and_gradient(r, :abs, -1.2)
648+
(1.2, -1.0)
649+
650+
julia> r.univariate_operators[3]
651+
:abs
652+
653+
julia> MOI.Nonlinear.eval_univariate_function_and_gradient(r, 3, -1.2)
654+
(1.2, -1.0)
655+
```
656+
"""
657+
function eval_univariate_function_and_gradient(
658+
registry::OperatorRegistry,
659+
op::Symbol,
660+
x::T,
661+
) where {T}
662+
id = registry.univariate_operator_to_id[op]
663+
return eval_univariate_function_and_gradient(registry, id, x)
664+
end
665+
666+
function eval_univariate_function_and_gradient(
667+
registry::OperatorRegistry,
668+
id::Integer,
669+
x::T,
670+
) where {T}
671+
if id <= registry.univariate_user_operator_start
672+
return _eval_univariate(id, x)::Tuple{T,T}
673+
end
674+
offset = id - registry.univariate_user_operator_start
675+
operator = registry.registered_univariate_operators[offset]
676+
ret_f = operator.f(x)
677+
check_return_type(T, ret_f)
678+
ret_f′ = operator.f′(x)
679+
check_return_type(T, ret_f′)
680+
return ret_f::T, ret_f′::T
681+
end
682+
571683
"""
572684
eval_univariate_hessian(
573685
registry::OperatorRegistry,
574-
op::Symbol,
686+
op::Union{Symbol,Integer},
575687
x::T,
576688
) where {T}
577689
578690
Evaluate the second-derivative of the operator `op(x)::T`, where `op` is a
579691
univariate function in `registry`.
692+
693+
If `op isa Integer`, then `op` is the index in
694+
`registry.univariate_operators[op]`.
695+
696+
## Example
697+
698+
```jldoctest
699+
julia> import MathOptInterface as MOI
700+
701+
julia> r = MOI.Nonlinear.OperatorRegistry();
702+
703+
julia> MOI.Nonlinear.eval_univariate_hessian(r, :sin, 1.0)
704+
-0.8414709848078965
705+
706+
julia> r.univariate_operators[16]
707+
:sin
708+
709+
julia> MOI.Nonlinear.eval_univariate_hessian(r, 16, 1.0)
710+
-0.8414709848078965
711+
712+
julia> -sin(1.0)
713+
-0.8414709848078965
714+
```
580715
"""
581716
function eval_univariate_hessian(
582717
registry::OperatorRegistry,
583718
op::Symbol,
584719
x::T,
585720
) where {T}
586721
id = registry.univariate_operator_to_id[op]
722+
return eval_univariate_hessian(registry, id, x)
723+
end
724+
725+
function eval_univariate_hessian(
726+
registry::OperatorRegistry,
727+
id::Integer,
728+
x::T,
729+
) where {T}
587730
if id <= registry.univariate_user_operator_start
588731
return _eval_univariate_2nd_deriv(id, x)::T
589732
end

test/Nonlinear/Nonlinear.jl

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,28 +360,62 @@ end
360360

361361
function test_eval_univariate_function()
362362
r = Nonlinear.OperatorRegistry()
363-
@test Nonlinear.eval_univariate_function(r, :+, 1.0) == 1.0
364-
@test Nonlinear.eval_univariate_function(r, :-, 1.0) == -1.0
365-
@test Nonlinear.eval_univariate_function(r, :abs, -1.1) == 1.1
366-
@test Nonlinear.eval_univariate_function(r, :abs, 1.1) == 1.1
363+
for (op, x, y) in [
364+
(:+, 1.0, 1.0),
365+
(:-, 1.0, -1.0),
366+
(:abs, -1.1, 1.1),
367+
(:abs, 1.1, 1.1),
368+
]
369+
id = r.univariate_operator_to_id[op]
370+
@test Nonlinear.eval_univariate_function(r, op, x) == y
371+
@test Nonlinear.eval_univariate_function(r, id, x) == y
372+
end
367373
return
368374
end
369375

370376
function test_eval_univariate_gradient()
371377
r = Nonlinear.OperatorRegistry()
372-
@test Nonlinear.eval_univariate_gradient(r, :+, 1.2) == 1.0
373-
@test Nonlinear.eval_univariate_gradient(r, :-, 1.2) == -1.0
374-
@test Nonlinear.eval_univariate_gradient(r, :abs, -1.1) == -1.0
375-
@test Nonlinear.eval_univariate_gradient(r, :abs, 1.1) == 1.0
378+
for (op, x, y) in [
379+
(:+, 1.2, 1.0),
380+
(:-, 1.2, -1.0),
381+
(:abs, -1.1, -1.0),
382+
(:abs, 1.1, 1.0),
383+
]
384+
id = r.univariate_operator_to_id[op]
385+
@test Nonlinear.eval_univariate_gradient(r, op, x) == y
386+
@test Nonlinear.eval_univariate_gradient(r, id, x) == y
387+
end
388+
return
389+
end
390+
391+
function test_eval_univariate_function_and_gradient()
392+
r = Nonlinear.OperatorRegistry()
393+
for (op, x, y) in [
394+
(:+, 1.2, (1.2, 1.0)),
395+
(:-, 1.2, (-1.2, -1.0)),
396+
(:abs, -1.1, (1.1, -1.0)),
397+
(:abs, 1.1, (1.1, 1.0)),
398+
]
399+
id = r.univariate_operator_to_id[op]
400+
@test Nonlinear.eval_univariate_function_and_gradient(r, op, x) == y
401+
@test Nonlinear.eval_univariate_function_and_gradient(r, id, x) == y
402+
end
376403
return
377404
end
378405

379406
function test_eval_univariate_hessian()
380407
r = Nonlinear.OperatorRegistry()
381-
@test Nonlinear.eval_univariate_hessian(r, :+, 1.2) == 0.0
382-
@test Nonlinear.eval_univariate_hessian(r, :-, 1.2) == 0.0
383-
@test Nonlinear.eval_univariate_hessian(r, :abs, -1.1) == 0.0
384-
@test Nonlinear.eval_univariate_hessian(r, :abs, 1.1) == 0.0
408+
for (op, x, y) in [
409+
(:+, 1.2, 0.0),
410+
(:-, 1.2, 0.0),
411+
(:abs, -1.1, 0.0),
412+
(:abs, 1.1, 0.0),
413+
(:sin, 1.0, -sin(1.0)),
414+
]
415+
id = r.univariate_operator_to_id[op]
416+
@test Nonlinear.eval_univariate_hessian(r, op, x) == y
417+
@test Nonlinear.eval_univariate_hessian(r, id, x) == y
418+
end
385419
return
386420
end
387421

0 commit comments

Comments
 (0)