Skip to content

Commit 305277a

Browse files
oschulzdevmotion
andcommitted
Add function setladj
Co-authored-by: David Widmann <[email protected]>
1 parent bdcf78c commit 305277a

File tree

10 files changed

+189
-5
lines changed

10 files changed

+189
-5
lines changed

Project.toml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,24 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
33
version = "0.1.7"
44

55
[deps]
6+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
89

10+
[weakdeps]
11+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
12+
13+
[extensions]
14+
ChangesOfVariablesInverseFunctionsExt = "InverseFunctions"
15+
916
[compat]
17+
InverseFunctions = "0.1"
1018
julia = "1"
1119

1220
[extras]
1321
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1422
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
23+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
1524

1625
[targets]
17-
test = ["Documenter", "ForwardDiff"]
26+
test = ["Documenter", "InverseFunctions", "ForwardDiff"]

README.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ changes for functions that perform a change of variables (like coordinate
1212
transformations).
1313

1414
`ChangesOfVariables` is a very lightweight package and has no dependencies
15-
beyond `Base`, `LinearAlgebra`, `Test`.
15+
beyond `Base`, `LinearAlgebra` and `Test` (plus a weak depdendency on
16+
`InverseFunctions`).
1617

1718
## Documentation
1819

docs/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
[deps]
22
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
33
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
4+
InverseFunctions = "3587e190-3f89-42d0-90ee-14403ec27112"
45
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
56
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
67

docs/make.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ using ChangesOfVariables
1111
DocMeta.setdocmeta!(
1212
ChangesOfVariables,
1313
:DocTestSetup,
14-
:(using ChangesOfVariables);
14+
:(using ChangesOfVariables, InverseFunctions);
1515
recursive=true,
1616
)
1717

docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,17 @@
55
```@docs
66
with_logabsdet_jacobian
77
NoLogAbsDetJacobian
8+
setladj
89
```
910

1011
## Test utility
1112

1213
```@docs
1314
ChangesOfVariables.test_with_logabsdet_jacobian
1415
```
16+
17+
## Additional functionality
18+
19+
```@docs
20+
ChangesOfVariables.FunctionWithLADJ
21+
```
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
module ChangesOfVariablesInverseFunctionsExt
2+
3+
using ChangesOfVariables
4+
using InverseFunctions
5+
6+
7+
struct InverseFunctionWithLADJ{InvF,LADJF} <: Function
8+
inv_f::InvF
9+
ladjf::LADJF
10+
end
11+
InverseFunctionWithLADJ(::Type{InvF}, ladjf::LADJF) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},LADJF}(InvF,ladjf)
12+
InverseFunctionWithLADJ(inv_f::InvF, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{InvF,Type{LADJF}}(inv_f,LADJF)
13+
InverseFunctionWithLADJ(::Type{InvF}, ::Type{LADJF}) where {InvF,LADJF} = InverseFunctionWithLADJ{Type{InvF},Type{LADJF}}(InvF,LADJF)
14+
15+
(f::InverseFunctionWithLADJ)(y) = f.inv_f(y)
16+
17+
function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctionWithLADJ, y)
18+
x = f.inv_f(y)
19+
return x, -f.ladjf(x)
20+
end
21+
22+
InverseFunctions.inverse(f::ChangesOfVariables.FunctionWithLADJ) = InverseFunctionWithLADJ(inverse(f.f), f.ladjf)
23+
InverseFunctions.inverse(f::InverseFunctionWithLADJ) = ChangesOfVariables.FunctionWithLADJ(inverse(f.inv_f), f.ladjf)
24+
25+
26+
@static if isdefined(InverseFunctions, :FunctionWithInverse)
27+
function ChangesOfVariables.with_logabsdet_jacobian(f::InverseFunctions.FunctionWithInverse, x)
28+
ChangesOfVariables.with_logabsdet_jacobian(f.f, x)
29+
end
30+
end
31+
32+
end # module ChangesOfVariablesInverseFunctionsExt

src/ChangesOfVariables.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,11 @@ using LinearAlgebra
1313
using Test
1414

1515
include("with_ladj.jl")
16+
include("setladj.jl")
1617
include("test.jl")
1718

19+
@static if !isdefined(Base, :get_extension)
20+
include("../ext/ChangesOfVariablesInverseFunctionsExt.jl")
21+
end
22+
1823
end # module

src/setladj.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
4+
"""
5+
struct FunctionWithLADJ{F,LADJF} <: Function
6+
7+
A function with an separate function to compute it's `logabddet(J)`.
8+
9+
Do not construct directly, use [`setladj(f, ladjf)`](@ref) instead.
10+
"""
11+
struct FunctionWithLADJ{F,LADJF} <: Function
12+
f::F
13+
ladjf::LADJF
14+
end
15+
FunctionWithLADJ(::Type{F}, ladjf::LADJF) where {F,LADJF} = FunctionWithLADJ{Type{F},LADJF}(F,ladjf)
16+
FunctionWithLADJ(f::F, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{F,Type{LADJF}}(f,LADJF)
17+
FunctionWithLADJ(::Type{F}, ::Type{LADJF}) where {F,LADJF} = FunctionWithLADJ{Type{F},Type{LADJF}}(F,LADJF)
18+
19+
(f::FunctionWithLADJ)(x) = f.f(x)
20+
21+
with_logabsdet_jacobian(f::FunctionWithLADJ, x) = f.f(x), f.ladjf(x)
22+
23+
24+
"""
25+
setladj(f, ladjf)::Function
26+
27+
Return a function that behaves like `f` in general and which has
28+
`with_logabsdet_jacobian(f, x) = f(x), ladjf(x)`.
29+
30+
Useful in cases where [`with_logabsdet_jacobian`](@ref) is not defined
31+
for `f`, or if `f` needs to be assigned a LADJ-calculation that is
32+
only valid within a given context, e.g. only for a
33+
limited argument type/range that is guaranteed by the use case but
34+
not in general, or that is optimized to a custom use case.
35+
36+
For example, `CUDA.CuArray` has no `with_logabsdet_jacobian` defined,
37+
but may be used to switch computing device for a part of a
38+
heterogenous computing function chain. Likewise, one may want to
39+
switch numerical precision for a part of a calculation.
40+
41+
The function (wrapper) returned by `setladj` supports
42+
[`InverseFunctions.inverse`](https://github.com/JuliaMath/InverseFunctions.jl)
43+
if `f` does so.
44+
45+
Example:
46+
47+
```jldoctest setladj
48+
VERSION < v"1.6" || begin # Support for ∘ requires Julia >= v1.6
49+
# Increases precition before calculation exp:
50+
foo = exp ∘ setladj(setinverse(Float64, Float32), _ -> 0)
51+
52+
# A log-value from some low-precision (e.g. GPU) computation:
53+
log_x = Float32(100)
54+
55+
# f(log_x) would return Inf32 without going to Float64:
56+
y, ladj = with_logabsdet_jacobian(foo, log_x)
57+
58+
r_log_x, ladj_inv = with_logabsdet_jacobian(inverse(foo), y)
59+
60+
ladj ≈ 100 ≈ -ladj_inv && r_log_x ≈ log_x
61+
end
62+
# output
63+
64+
true
65+
```
66+
"""
67+
setladj(f, ladjf) = FunctionWithLADJ(_unwrap_f(f), ladjf)
68+
export setladj
69+
70+
_unwrap_f(f) = f
71+
_unwrap_f(f::FunctionWithLADJ) = f.f

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,14 @@ import Documenter
77
Test.@testset "Package ChangesOfVariables" begin
88
include("test_test.jl")
99
include("test_with_ladj.jl")
10+
include("test_setladj.jl")
1011

1112
# doctests
1213
Documenter.DocMeta.setdocmeta!(
1314
ChangesOfVariables,
1415
:DocTestSetup,
15-
:(using ChangesOfVariables);
16+
:(using ChangesOfVariables, InverseFunctions);
1617
recursive=true,
1718
)
1819
Documenter.doctest(ChangesOfVariables)
1920
end # testset
20-

test/test_setladj.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2+
3+
using Test
4+
using ChangesOfVariables
5+
using InverseFunctions
6+
7+
const ChangesOfVariablesInverseFunctionsExt = if isdefined(Base, :get_extension)
8+
Base.get_extension(ChangesOfVariables, :ChangesOfVariablesInverseFunctionsExt)
9+
else
10+
ChangesOfVariables.ChangesOfVariablesInverseFunctionsExt
11+
end
12+
const InverseFunctionWithLADJ = ChangesOfVariablesInverseFunctionsExt.InverseFunctionWithLADJ
13+
14+
include("getjacobian.jl")
15+
16+
17+
# Dummy testing type that looks like something that represents abstract zeros:
18+
struct _Zero{T} end
19+
_Zero(::T) where {T} = _Zero{T}()
20+
21+
22+
@testset "setladj" begin
23+
@test @inferred(setladj(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}}
24+
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},Type{_Zero}}
25+
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, _Zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),Type{_Zero}}
26+
@test @inferred(ChangesOfVariables.FunctionWithLADJ(Real, zero)) isa ChangesOfVariables.FunctionWithLADJ{Type{Real},typeof(zero)}
27+
@test @inferred(ChangesOfVariables.FunctionWithLADJ(widen, zero)) isa ChangesOfVariables.FunctionWithLADJ{typeof(widen),typeof(zero)}
28+
29+
@test @inferred(InverseFunctionWithLADJ(Real, _Zero)) isa InverseFunctionWithLADJ{Type{Real},Type{_Zero}}
30+
@test @inferred(InverseFunctionWithLADJ(widen, _Zero)) isa InverseFunctionWithLADJ{typeof(widen),Type{_Zero}}
31+
@test @inferred(InverseFunctionWithLADJ(Real, zero)) isa InverseFunctionWithLADJ{Type{Real},typeof(zero)}
32+
@test @inferred(InverseFunctionWithLADJ(widen, zero)) isa InverseFunctionWithLADJ{typeof(widen),typeof(zero)}
33+
34+
@test @inferred(setladj(setladj(exp, x -> 0), x -> x)) isa ChangesOfVariables.FunctionWithLADJ{typeof(exp)}
35+
ChangesOfVariables.test_with_logabsdet_jacobian(setladj(setladj(exp, x -> 0), x -> x), 1.7, getjacobian)
36+
37+
x = 4.2
38+
y = x^2
39+
40+
f_fwd = setladj(x -> x^2, x -> log(2*x))
41+
f_inv = setladj(y -> sqrt(y), y -> log(inv(2*sqrt(y))))
42+
ChangesOfVariables.test_with_logabsdet_jacobian(f_fwd, x, getjacobian)
43+
ChangesOfVariables.test_with_logabsdet_jacobian(f_inv, y, getjacobian)
44+
45+
f = @inferred setladj(setinverse(x -> x^2, x -> sqrt(x)), x -> log(2*x))
46+
@test @inferred(f(x)) == y
47+
ChangesOfVariables.test_with_logabsdet_jacobian(f, x, getjacobian)
48+
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(f), y, getjacobian)
49+
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(f)), x, getjacobian)
50+
@inferred(inverse(inverse(f))) isa ChangesOfVariables.FunctionWithLADJ
51+
52+
@static if isdefined(InverseFunctions, :setinverse)
53+
g = setinverse(f_fwd, f_inv)
54+
ChangesOfVariables.test_with_logabsdet_jacobian(g, x, getjacobian)
55+
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(g), y, getjacobian)
56+
ChangesOfVariables.test_with_logabsdet_jacobian(inverse(inverse(g)), x, getjacobian)
57+
end
58+
end

0 commit comments

Comments
 (0)