Skip to content
Open
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"

[weakdeps]
BaseType = "7fbed51b-1ef5-4d67-9085-a4a9b26f478c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Expand All @@ -18,6 +19,7 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[extensions]
MeasurementsBaseTypeExt = "BaseType"
MeasurementsForwardDiffExt = "ForwardDiff"
MeasurementsJunoExt = "Juno"
MeasurementsMakieExt = "Makie"
MeasurementsRecipesBaseExt = "RecipesBase"
Expand All @@ -28,6 +30,7 @@ MeasurementsUnitfulExt = "Unitful"
Aqua = "0.8"
BaseType = "0.2"
Calculus = "0.4.1, 0.5"
ForwardDiff = "0.10.36, 1"
Juno = "0.8"
LinearAlgebra = "<0.0.1, 1"
Makie = "0.21, 0.22"
Expand All @@ -43,6 +46,7 @@ julia = "1.10"
[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
BaseType = "7fbed51b-1ef5-4d67-9085-a4a9b26f478c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
Expand All @@ -53,4 +57,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"

[targets]
test = ["Aqua", "Makie", "BaseType", "QuadGK", "RecipesBase", "SpecialFunctions", "Statistics", "Test", "Unitful"]
test = ["Aqua", "Makie", "BaseType", "QuadGK", "RecipesBase", "SpecialFunctions", "Statistics", "Test", "Unitful", "ForwardDiff"]
188 changes: 188 additions & 0 deletions ext/MeasurementsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
module MeasurementsForwardDiffExt

using ForwardDiff: Dual, DiffRules, NaNMath, LogExpFunctions, SpecialFunctions,≺
using Measurements: Measurement, Measurements
import Base: +,-,/,*,promote_rule
using ForwardDiff: AMBIGUOUS_TYPES, partials, values, Partials, value
using ForwardDiff: ForwardDiff

#patch this until is fixed in ForwardDiff

@generated function ForwardDiff.construct_seeds(::Type{Partials{N,V}}) where {N,V<:Measurement}
return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...)

Check warning on line 12 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L11-L12

Added lines #L11 - L12 were not covered by tests
end

#needs redefinition here, because generated functions don't allow extra definitions
@generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i}
ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...)
return :(Partials($(ex)))

Check warning on line 18 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L16-L18

Added lines #L16 - L18 were not covered by tests
end

function promote_rule(::Type{Measurement{V}}, ::Type{Dual{T, V, N}}) where {T,V,N}
Dual{Measurement{T}, V, N}

Check warning on line 22 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L21-L22

Added lines #L21 - L22 were not covered by tests
end

function promote_rule(::Type{Measurement{V1}}, ::Type{Dual{T, V2, N}}) where {V1<:AbstractFloat, T, V2, N}
Vx = promote_rule(Measurement{V1},V2)
return Dual{T , Vx, N}

Check warning on line 27 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L25-L27

Added lines #L25 - L27 were not covered by tests
end

function overload_ambiguous_binary(M,f)
Mf = :($M.$f)
return quote
@inline function $Mf(x::Dual{Tx}, y::Measurement) where {Tx}
∂y = Dual{Tx}(y)
$Mf(x,∂y)
end

@inline function $Mf(x::Measurement,y::Dual{Ty}) where {Ty}
∂x = Dual{Ty}(x)
$Mf(∂x,y)
end
end
end

macro define_ternary_dual_op2(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body)
FD = ForwardDiff
R = Measurement
defs = quote
@inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body
@inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body
@inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body
@inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body

Check warning on line 54 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L49-L54

Added lines #L49 - L54 were not covered by tests
end
for Q in AMBIGUOUS_TYPES
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body
@inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body

Check warning on line 60 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L60

Added line #L60 was not covered by tests
end
append!(defs.args, expr.args)
end
expr = quote
@inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body
@inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body
@inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body

Check warning on line 67 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L65-L67

Added lines #L65 - L67 were not covered by tests
end
append!(defs.args, expr.args)
return esc(defs)
end

#use DiffRules.jl rules

for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing)
if (M, f) in ((:Base, :^), (:NaNMath, :pow))
continue # Skip methods which we define elsewhere.
elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f))
continue # Skip rules for methods not defined in the current scope

Check warning on line 79 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L79

Added line #L79 was not covered by tests
end
if arity == 2
eval(overload_ambiguous_binary(M,f))
else

Check warning on line 83 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L83

Added line #L83 was not covered by tests
# error("ForwardDiff currently only knows how to autogenerate Dual definitions for unary and binary functions.")
# However, the presence of N-ary rules need not cause any problems here, they can simply be ignored.
end
end

#ternary overloads
@define_ternary_dual_op2(
Base.hypot,
ForwardDiff.calc_hypot(x, y, z, Txyz),
ForwardDiff.calc_hypot(x, y, z, Txy),
ForwardDiff.calc_hypot(x, y, z, Txz),
ForwardDiff.calc_hypot(x, y, z, Tyz),
ForwardDiff.calc_hypot(x, y, z, Tx),
ForwardDiff.calc_hypot(x, y, z, Ty),
ForwardDiff.calc_hypot(x, y, z, Tz),
)

@define_ternary_dual_op2(
Base.fma,
ForwardDiff.calc_fma_xyz(x, y, z), # xyz_body
ForwardDiff.calc_fma_xy(x, y, z), # xy_body
ForwardDiff.calc_fma_xz(x, y, z), # xz_body
Base.fma(y, x, z), # yz_body
Dual{Tx}(Base.fma(value(x), y, z), partials(x) * y), # x_body
Base.fma(y, x, z), # y_body
Dual{Tz}(Base.fma(x, y, value(z)), partials(z)) # z_body
)

@define_ternary_dual_op2(
Base.muladd,
ForwardDiff.calc_muladd_xyz(x, y, z), # xyz_body
ForwardDiff.calc_muladd_xy(x, y, z), # xy_body
ForwardDiff.calc_muladd_xz(x, y, z), # xz_body
Base.muladd(y, x, z), # yz_body
Dual{Tx}(Base.muladd(value(x), y, z), partials(x) * y), # x_body
Base.muladd(y, x, z), # y_body
Dual{Tz}(Base.muladd(x, y, value(z)), partials(z)) # z_body
)

#=
Derivatives of Measurements.value and Measurements.uncertainty
Apply those functions to the real and partial part.
=#
function Measurements.value(x::Dual{T,V,N}) where {T, V <: Measurement, N}
x_value = Measurements.value(ForwardDiff.value(x))
p = partials(x)
p_value = ntuple(i -> Measurements.value(p[i]),Val(N))
return ForwardDiff.Dual{T}(x_value,Partials(p_value))
end

function Measurements.uncertainty(x::Dual{T,V,N}) where {T, V <: Measurement, N}
x_err = Measurements.uncertainty(ForwardDiff.value(x))
p = partials(x)
p_err = ntuple(i -> Measurements.uncertainty(p[i]),Val(N))
return ForwardDiff.Dual{T}(x_err,Partials(p_err))
end

#=
start of derivatives of Measurements.measurement

Derivative with respect to the value:
f(x) = measurement(n*x,m*y), derivative(f,x) = n ± 0

Derivative with respect to the uncertainty:
f(x) = measurement(n*x,m*y), derivative(f,x) = 0 ± m
=#
function dmeasurement_val(x::Dual{T,V,N},err::Real) where {T,V,N}

val = ForwardDiff.value(x)
_1,_0 = oneunit(val),zero(val)
x_value = Measurements.measurement(val,err)
x_der = Measurements.measurement(_1,_0)
v = Dual{T}(x_value,x_der * partials(x))
end

function dmeasurement_err(x::Real,err::Dual{T,V,N}) where {T,V,N}
errval = ForwardDiff.value(err)
_1,_0 = oneunit(errval),zero(errval)
x_value = Measurements.measurement(x,errval)
x_der = Measurements.measurement(_0,_1)
v = Dual{T}(x_value,x_der * partials(err))
end

function dmeasurement_val_and_err(x::Dual{T,V1,N},err::Dual{T,V2,N}) where {T,V1,V2,N}
xval = ForwardDiff.value(x)
errval = ForwardDiff.value(err)
xp = partials(x)
errp = partials(err)
measurement_primal = Measurements.measurement(xval,errval)
measurement_partials_tuple = ntuple(i -> Measurements.measurement(xp[i],errp[i]),Val(N))
return Dual{T}(measurement_primal,Partials(measurement_partials_tuple))
end

function Measurements.measurement(x::ForwardDiff.Dual{T,V,N}) where {T,V,N}
return dmeasurement_val(x,zero(ForwardDiff.value(x)))

Check warning on line 178 in ext/MeasurementsForwardDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/MeasurementsForwardDiffExt.jl#L177-L178

Added lines #L177 - L178 were not covered by tests
end

ForwardDiff.@define_binary_dual_op(
Measurements.measurement,
dmeasurement_val_and_err(x,y),
dmeasurement_val(x,y),
dmeasurement_err(x,y),
)

end #module
33 changes: 33 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using Measurements, SpecialFunctions, QuadGK, Calculus, BaseType, Makie
using Test, LinearAlgebra, Statistics, Unitful, Printf, Aqua
using ForwardDiff

Aqua.test_all(Measurements)

Expand Down Expand Up @@ -1067,3 +1068,35 @@ end
@test base_numeric_type(typeof(x)) == T
end
end


fd_f1(x,y) = measurement(2x,3y)
fd_f2(x) = fd_f1(x,x)
fd_f3(x,y) = muladd(x,y,1)
fd_f4(x,y) = value(fd_f1(x,y))
fd_f5(x,y) = uncertainty(fd_f1(x,y))

@testset "ForwardDiff" begin
x1 = 1.0 ± 0.1
y1 = 30.0 ± 0.7
#some common operations, no special handling in the extension, just wrapping in a dual
@test ForwardDiff.derivative(Base.Fix1(+,x1),y1) == 1.0 ± 0.0
@test ForwardDiff.derivative(Base.Fix1(+,y1),x1) == 1.0 ± 0.0
@test ForwardDiff.derivative(Base.Fix1(*,y1),x1) == y1
@test ForwardDiff.derivative(Base.Fix1(*,x1),y1) == x1
@test ForwardDiff.derivative(Base.Fix2(/,y1),x1) == 1/y1
@test ForwardDiff.derivative(Base.Fix1(/,x1),y1) == -x1/(y1*y1)

#test ternary op
@test ForwardDiff.derivative(Base.Fix1(fd_f3,y1),x1) == y1
@test ForwardDiff.derivative(Base.Fix1(fd_f3,x1),y1) == x1

#derivatives of Measurements.measurement
@test ForwardDiff.derivative(Base.Fix1(fd_f1,1.0),1.213) == 0.0 ± 3.0
@test ForwardDiff.derivative(Base.Fix2(fd_f1,1.0),1.213) == 2.0 ± 0.0
@test ForwardDiff.derivative(fd_f2,1.213) == 2.0 ± 3.0

#test value/uncertainty getters
@test ForwardDiff.derivative(Base.Fix2(fd_f4,1.0),1.213) == 2.0
@test ForwardDiff.derivative(Base.Fix1(fd_f5,1.0),1.213) == 3.0
end
Loading