Skip to content

Commit e19783e

Browse files
committed
added: GradientBuffer
1 parent 3640db7 commit e19783e

File tree

1 file changed

+18
-4
lines changed

1 file changed

+18
-4
lines changed

src/general.jl

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,28 @@ const DEFAULT_EWT = 0.0
99
"Abstract type for all differentiation buffers."
1010
abstract type DifferentiationBuffer end
1111

12-
"Struct with both function and configuration for ForwardDiff differentiation."
13-
struct JacobianBuffer{FT<:Function, CT<:ForwardDiff.JacobianConfig} <: DifferentiationBuffer
12+
function Base.show(io::IO, buffer::DifferentiationBuffer)
13+
return print(io, "DifferentiationBuffer with a $(typeof(buffer.config).name.name)")
14+
end
15+
16+
"Struct with both function and configuration for ForwardDiff gradient."
17+
struct GradientBuffer{FT<:Function, CT<:ForwardDiff.GradientConfig} <: DifferentiationBuffer
1418
f!::FT
1519
config::CT
1620
end
1721

18-
function Base.show(io::IO, buffer::DifferentiationBuffer)
19-
return print(io, "DifferentiationBuffer with a $(typeof(buffer.config).name.name)")
22+
GradientBuffer(f!, x) = GradientBuffer(f!, ForwardDiff.GradientConfig(f!, x))
23+
24+
function gradient!(
25+
g, buffer::GradientBuffer, x
26+
)
27+
return ForwardDiff.gradient!(g, buffer.f!, x, buffer.config)
28+
end
29+
30+
"Struct with both function and configuration for ForwardDiff Jacobian."
31+
struct JacobianBuffer{FT<:Function, CT<:ForwardDiff.JacobianConfig} <: DifferentiationBuffer
32+
f!::FT
33+
config::CT
2034
end
2135

2236
"Create a JacobianBuffer with function `f!`, output `y` and input `x`."

0 commit comments

Comments
 (0)