Skip to content

Commit eb6186c

Browse files
committed
Update
1 parent 8948063 commit eb6186c

File tree

2 files changed

+23
-24
lines changed

2 files changed

+23
-24
lines changed

src/Nonlinear/ReverseAD/forward_over_reverse.jl

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -108,31 +108,16 @@ function _eval_hessian_chunk(
108108
end
109109

110110
# A wrapper function to avoid dynamic dispatch.
111-
function _hessian_slice_inner(d, ex, chunk::Int)
112-
@assert 1 <= chunk <= MAX_CHUNK
113-
@assert MAX_CHUNK == 10
114-
if chunk == 1
115-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{1,Float64})
116-
elseif chunk == 2
117-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{2,Float64})
118-
elseif chunk == 3
119-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{3,Float64})
120-
elseif chunk == 4
121-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{4,Float64})
122-
elseif chunk == 5
123-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{5,Float64})
124-
elseif chunk == 6
125-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{6,Float64})
126-
elseif chunk == 7
127-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{7,Float64})
128-
elseif chunk == 8
129-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{8,Float64})
130-
elseif chunk == 9
131-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{9,Float64})
132-
else
133-
_hessian_slice_inner(d, ex, ForwardDiff.Partials{10,Float64})
111+
function _generate_hessian_slice_inner()
112+
exprs = map(1:MAX_CHUNK) do id
113+
return :(_hessian_slice_inner(d, ex, ForwardDiff.Partials{$id,Float64}))
134114
end
135-
return
115+
return _create_binary_switch(1:MAX_CHUNK, exprs)
116+
end
117+
118+
@eval function _hessian_slice_inner(d, ex, id::Int)
119+
$(_generate_hessian_slice_inner())
120+
return error("Invalid chunk size: $id")
136121
end
137122

138123
function _hessian_slice_inner(d, ex, ::Type{T}) where {T}

test/Nonlinear/ReverseAD.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1385,6 +1385,20 @@ function test_eval_user_defined_operator_type_mismatch()
13851385
return
13861386
end
13871387

1388+
function test_generate_hessian_slice_inner()
1389+
# Test that it evaluates without error. The code contents are tested
1390+
# elsewhere.
1391+
MOI.Nonlinear.ReverseAD._generate_hessian_slice_inner()
1392+
d = ex = nothing # These arguments are untyped and not needed for this test
1393+
for id in [0, MAX_CHUNK + 1]
1394+
@test_throws(
1395+
ErrorException("Invalid chunk size: $id"),
1396+
MOI.Nonlinear.ReverseAD._hessian_slice_inner(d, ex, id),
1397+
)
1398+
end
1399+
return
1400+
end
1401+
13881402
end # module
13891403

13901404
TestReverseAD.runtests()

0 commit comments

Comments
 (0)