-
Notifications
You must be signed in to change notification settings - Fork 42
add ForwardDiff extension #178
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
longemen3000
wants to merge
13
commits into
JuliaPhysics:master
Choose a base branch
from
longemen3000:forwarddiff-ext
base: master
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from 4 commits
Commits
Show all changes
13 commits
Select commit
Hold shift + click to select a range
61679cf
add ForwardDiff extension
longemen3000 cf506b0
Update runtests.jl
longemen3000 3c6ff6f
add derivatives for measurement, value, uncertainty
longemen3000 07c66fc
improve runtests
longemen3000 8cc9da5
add more tests and descriptions of what the code does
longemen3000 603ccee
more tests
longemen3000 f128b66
add oneunit(Dual{Measurement})
longemen3000 ae39ad2
add a promote_rule for Measurement and BigFloat
longemen3000 757c10d
typo
longemen3000 389ef3c
add more extensive test for ternaries
longemen3000 3891902
remove unused promote_rule
longemen3000 ae0f7ca
try to ternary (two duals and one measurement)
longemen3000 2a41c0c
remove unnecessary ifs
longemen3000 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,188 @@ | ||
| module MeasurementsForwardDiffExt | ||
|
|
||
| using ForwardDiff: Dual, DiffRules, NaNMath, LogExpFunctions, SpecialFunctions,≺ | ||
| using Measurements: Measurement, Measurements | ||
| import Base: +,-,/,*,promote_rule | ||
| using ForwardDiff: AMBIGUOUS_TYPES, partials, values, Partials, value | ||
| using ForwardDiff: ForwardDiff | ||
|
|
||
| #patch this until is fixed in ForwardDiff | ||
|
|
||
| @generated function ForwardDiff.construct_seeds(::Type{Partials{N,V}}) where {N,V<:Measurement} | ||
| return Expr(:tuple, [:(single_seed(Partials{N,V}, Val{$i}())) for i in 1:N]...) | ||
| end | ||
|
|
||
| #needs redefinition here, because generated functions don't allow extra definitions | ||
| @generated function single_seed(::Type{Partials{N,V}}, ::Val{i}) where {N,V,i} | ||
| ex = Expr(:tuple, [ifelse(i === j, :(oneunit(V)), :(zero(V))) for j in 1:N]...) | ||
| return :(Partials($(ex))) | ||
| end | ||
|
|
||
| function promote_rule(::Type{Measurement{V}}, ::Type{Dual{T, V, N}}) where {T,V,N} | ||
| Dual{Measurement{T}, V, N} | ||
| end | ||
|
|
||
| function promote_rule(::Type{Measurement{V1}}, ::Type{Dual{T, V2, N}}) where {V1<:AbstractFloat, T, V2, N} | ||
| Vx = promote_rule(Measurement{V1},V2) | ||
| return Dual{T , Vx, N} | ||
| end | ||
|
|
||
| function overload_ambiguous_binary(M,f) | ||
| Mf = :($M.$f) | ||
| return quote | ||
| @inline function $Mf(x::Dual{Tx}, y::Measurement) where {Tx} | ||
| ∂y = Dual{Tx}(y) | ||
| $Mf(x,∂y) | ||
| end | ||
|
|
||
| @inline function $Mf(x::Measurement,y::Dual{Ty}) where {Ty} | ||
| ∂x = Dual{Ty}(x) | ||
| $Mf(∂x,y) | ||
| end | ||
| end | ||
| end | ||
|
|
||
| macro define_ternary_dual_op2(f, xyz_body, xy_body, xz_body, yz_body, x_body, y_body, z_body) | ||
| FD = ForwardDiff | ||
| R = Measurement | ||
| defs = quote | ||
| @inline $(f)(x::$FD.Dual{Txy}, y::$FD.Dual{Txy}, z::$R) where {Txy} = $xy_body | ||
| @inline $(f)(x::$FD.Dual{Tx}, y::$FD.Dual{Ty}, z::$R) where {Tx, Ty} = Ty ≺ Tx ? $x_body : $y_body | ||
| @inline $(f)(x::$FD.Dual{Txz}, y::$R, z::$FD.Dual{Txz}) where {Txz} = $xz_body | ||
| @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$FD.Dual{Tz}) where {Tx,Tz} = Tz ≺ Tx ? $x_body : $z_body | ||
| @inline $(f)(x::$R, y::$FD.Dual{Tyz}, z::$FD.Dual{Tyz}) where {Tyz} = $yz_body | ||
| @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$FD.Dual{Tz}) where {Ty,Tz} = Tz ≺ Ty ? $y_body : $z_body | ||
| end | ||
| for Q in AMBIGUOUS_TYPES | ||
| expr = quote | ||
| @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$Q) where {Tx} = $x_body | ||
| @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$Q) where {Ty} = $y_body | ||
| @inline $(f)(x::$R, y::$Q, z::$FD.Dual{Tz}) where {Tz} = $z_body | ||
| end | ||
| append!(defs.args, expr.args) | ||
| end | ||
| expr = quote | ||
| @inline $(f)(x::$FD.Dual{Tx}, y::$R, z::$R) where {Tx} = $x_body | ||
| @inline $(f)(x::$R, y::$FD.Dual{Ty}, z::$R) where {Ty} = $y_body | ||
| @inline $(f)(x::$R, y::$R, z::$FD.Dual{Tz}) where {Tz} = $z_body | ||
| end | ||
| append!(defs.args, expr.args) | ||
| return esc(defs) | ||
| end | ||
|
|
||
| #use DiffRules.jl rules | ||
|
|
||
| for (M, f, arity) in DiffRules.diffrules(filter_modules = nothing) | ||
| if (M, f) in ((:Base, :^), (:NaNMath, :pow)) | ||
| continue # Skip methods which we define elsewhere. | ||
| elseif !(isdefined(@__MODULE__, M) && isdefined(getfield(@__MODULE__, M), f)) | ||
| continue # Skip rules for methods not defined in the current scope | ||
| end | ||
| if arity == 2 | ||
| eval(overload_ambiguous_binary(M,f)) | ||
| else | ||
| # error("ForwardDiff currently only knows how to autogenerate Dual definitions for unary and binary functions.") | ||
| # However, the presence of N-ary rules need not cause any problems here, they can simply be ignored. | ||
| end | ||
| end | ||
|
|
||
| #ternary overloads | ||
| @define_ternary_dual_op2( | ||
| Base.hypot, | ||
| ForwardDiff.calc_hypot(x, y, z, Txyz), | ||
| ForwardDiff.calc_hypot(x, y, z, Txy), | ||
| ForwardDiff.calc_hypot(x, y, z, Txz), | ||
| ForwardDiff.calc_hypot(x, y, z, Tyz), | ||
| ForwardDiff.calc_hypot(x, y, z, Tx), | ||
| ForwardDiff.calc_hypot(x, y, z, Ty), | ||
| ForwardDiff.calc_hypot(x, y, z, Tz), | ||
| ) | ||
|
|
||
| @define_ternary_dual_op2( | ||
| Base.fma, | ||
| ForwardDiff.calc_fma_xyz(x, y, z), # xyz_body | ||
| ForwardDiff.calc_fma_xy(x, y, z), # xy_body | ||
| ForwardDiff.calc_fma_xz(x, y, z), # xz_body | ||
| Base.fma(y, x, z), # yz_body | ||
| Dual{Tx}(Base.fma(value(x), y, z), partials(x) * y), # x_body | ||
| Base.fma(y, x, z), # y_body | ||
| Dual{Tz}(Base.fma(x, y, value(z)), partials(z)) # z_body | ||
| ) | ||
|
|
||
| @define_ternary_dual_op2( | ||
| Base.muladd, | ||
| ForwardDiff.calc_muladd_xyz(x, y, z), # xyz_body | ||
| ForwardDiff.calc_muladd_xy(x, y, z), # xy_body | ||
| ForwardDiff.calc_muladd_xz(x, y, z), # xz_body | ||
| Base.muladd(y, x, z), # yz_body | ||
| Dual{Tx}(Base.muladd(value(x), y, z), partials(x) * y), # x_body | ||
| Base.muladd(y, x, z), # y_body | ||
| Dual{Tz}(Base.muladd(x, y, value(z)), partials(z)) # z_body | ||
| ) | ||
|
|
||
| #= | ||
| Derivatives of Measurements.value and Measurements.uncertainty | ||
| Apply those functions to the real and partial part. | ||
| =# | ||
| function Measurements.value(x::Dual{T,V,N}) where {T, V <: Measurement, N} | ||
| x_value = Measurements.value(ForwardDiff.value(x)) | ||
| p = partials(x) | ||
| p_value = ntuple(i -> Measurements.value(p[i]),Val(N)) | ||
| return ForwardDiff.Dual{T}(x_value,Partials(p_value)) | ||
| end | ||
|
|
||
| function Measurements.uncertainty(x::Dual{T,V,N}) where {T, V <: Measurement, N} | ||
| x_err = Measurements.uncertainty(ForwardDiff.value(x)) | ||
| p = partials(x) | ||
| p_err = ntuple(i -> Measurements.uncertainty(p[i]),Val(N)) | ||
| return ForwardDiff.Dual{T}(x_err,Partials(p_err)) | ||
| end | ||
|
|
||
| #= | ||
| start of derivatives of Measurements.measurement | ||
|
|
||
| Derivative with respect to the value: | ||
| f(x) = measurement(n*x,m*y), derivative(f,x) = n ± 0 | ||
|
|
||
| Derivative with respect to the uncertainty: | ||
| f(x) = measurement(n*x,m*y), derivative(f,x) = 0 ± m | ||
| =# | ||
| function dmeasurement_val(x::Dual{T,V,N},err::Real) where {T,V,N} | ||
|
|
||
| val = ForwardDiff.value(x) | ||
| _1,_0 = oneunit(val),zero(val) | ||
| x_value = Measurements.measurement(val,err) | ||
| x_der = Measurements.measurement(_1,_0) | ||
| v = Dual{T}(x_value,x_der * partials(x)) | ||
| end | ||
|
|
||
| function dmeasurement_err(x::Real,err::Dual{T,V,N}) where {T,V,N} | ||
| errval = ForwardDiff.value(err) | ||
| _1,_0 = oneunit(errval),zero(errval) | ||
| x_value = Measurements.measurement(x,errval) | ||
| x_der = Measurements.measurement(_0,_1) | ||
| v = Dual{T}(x_value,x_der * partials(err)) | ||
| end | ||
|
|
||
| function dmeasurement_val_and_err(x::Dual{T,V1,N},err::Dual{T,V2,N}) where {T,V1,V2,N} | ||
| xval = ForwardDiff.value(x) | ||
| errval = ForwardDiff.value(err) | ||
| xp = partials(x) | ||
| errp = partials(err) | ||
| measurement_primal = Measurements.measurement(xval,errval) | ||
| measurement_partials_tuple = ntuple(i -> Measurements.measurement(xp[i],errp[i]),Val(N)) | ||
| return Dual{T}(measurement_primal,Partials(measurement_partials_tuple)) | ||
| end | ||
|
|
||
| function Measurements.measurement(x::ForwardDiff.Dual{T,V,N}) where {T,V,N} | ||
| return dmeasurement_val(x,zero(ForwardDiff.value(x))) | ||
| end | ||
|
|
||
| ForwardDiff.@define_binary_dual_op( | ||
| Measurements.measurement, | ||
| dmeasurement_val_and_err(x,y), | ||
| dmeasurement_val(x,y), | ||
| dmeasurement_err(x,y), | ||
| ) | ||
|
|
||
| end #module | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.