1
1
# This file is a part of ChangesOfVariables.jl, licensed under the MIT License (MIT).
2
2
3
3
4
- _to_realvec_and_back (V:: AbstractVector{<:Real} ) = V, identity
5
- _to_realvec_and_back (x:: Real ) = [x], V -> V[1 ]
6
- _to_realvec_and_back (x:: Complex ) = [real (x), imag (x)], V -> Complex (V[1 ], V[2 ])
7
- _to_realvec_and_back (x:: NTuple{N} ) where N = [x... ], V -> ntuple (i -> V[i], Val (N))
4
+ _generalized_logabsdet (A) = logabsdet (A)
5
+ _generalized_logabsdet (x:: Real ) = log (abs (x))
8
6
9
- function _to_realvec_and_back (x:: Ref )
10
- xval = x[]
11
- V, to_xval = _to_realvec_and_back (xval)
12
- back_to_ref (V) = Ref (to_xval (V))
13
- return (V, back_to_ref)
14
- end
15
-
16
- _to_realvec_and_back (A:: AbstractArray{<:Real} ) = vec (A), V -> reshape (V, size (A))
17
-
18
- function _to_realvec_and_back (A:: AbstractArray{Complex{T}, N} ) where {T<: Real , N}
19
- RA = cat (real .(A), imag .(A), dims = N+ 1 )
20
- V, to_array = _to_realvec_and_back (RA)
21
- function back_to_complex (V)
22
- RA = to_array (V)
23
- Complex .(view (RA, map (_ -> :, size (A))... , 1 ), view (RA, map (_ -> :, size (A))... , 2 ))
24
- end
25
- return (V, back_to_complex)
26
- end
27
-
28
-
29
- _to_realvec (x) = _to_realvec_and_back (x)[1 ]
30
-
31
-
32
- function _auto_with_logabsdet_jacobian (f, x, getjacobian)
7
+ function _auto_with_logabsdet_jacobian (f, x, getjacobian, rv_and_back)
33
8
y = f (x)
34
- V, to_x = _to_realvec_and_back (x)
35
- vf (V) = _to_realvec (f (to_x (V)))
36
- ladj = logabsdet (getjacobian (vf, V))[1 ]
9
+ V, to_x = rv_and_back (x)
10
+ vf (V) = rv_and_back (f (to_x (V)))[ 1 ]
11
+ ladj = _generalized_logabsdet (getjacobian (vf, V))[1 ]
37
12
return (y, ladj)
38
13
end
39
14
40
15
41
16
"""
42
17
ChangesOfVariables.test_with_logabsdet_jacobian(
43
- f, x, getjacobian;
44
- test_inferred::Bool = true, kwargs...
18
+ f, x, getjacobian, rv_and_back = x -> (x, identity) ;
19
+ compare = isapprox, test_inferred::Bool = true, kwargs...
45
20
)
46
21
47
22
Test if [`with_logabsdet_jacobian(f, x)`](@ref) is implemented correctly.
@@ -51,26 +26,35 @@ equal to `(f(x), logabsdet(getjacobian(f, x)))`
51
26
52
27
So the test uses `getjacobian(f, x)` to calculate a reference Jacobian for
53
28
`f` at `x`. Passing `ForwardDiff.jabobian`, `Zygote.jacobian` or similar as
54
- the `getjacobian` function will do fine in most cases.
29
+ the `getjacobian` function will do fine in most cases. If input and output
30
+ of `f` are real scalar values, use `ForwardDiff.derivative`.
31
+
32
+ If `getjacobian(f, x)` can't handle the type of `x` of `f(x)` because they
33
+ are not real-valued vectors, use the `rv_and_back` argument to pass a
34
+ function with the following behavior
55
35
56
- If `x` or `f(x)` are real-valued scalars or complex-valued scalars or arrays,
57
- the test will try to reshape them automatically, to account for limitations
58
- of (e.g.) `ForwardDiff` and to ensure the result of `getjacobian` is a real
59
- matrix.
36
+ ```julia
37
+ v, back = rv_and_back(x)
38
+ v isa AbstractVector{<:Real}
39
+ back(v) == x
40
+ ```
60
41
61
- If `test_inferred == true` will test type inference on
62
- `with_logabsdet_jacobian` .
42
+ If `test_inferred == true`, type inference on `with_logabsdet_jacobian` will
43
+ be tested .
63
44
64
- `kwargs...` are forwarded to `isapprox `.
45
+ `kwargs...` are forwarded to `compare `.
65
46
"""
66
- function test_with_logabsdet_jacobian (f, x, getjacobian; compare= isapprox, test_inferred:: Bool = true , kwargs... )
47
+ function test_with_logabsdet_jacobian (
48
+ f, x, getjacobian, rv_and_back = x -> (x, identity);
49
+ compare = isapprox, test_inferred:: Bool = true , kwargs...
50
+ )
67
51
@testset " test_with_logabsdet_jacobian: $f with input $x " begin
68
52
y, ladj = if test_inferred
69
53
@inferred with_logabsdet_jacobian (f, x)
70
54
else
71
55
with_logabsdet_jacobian (f, x)
72
56
end
73
- ref_y, ref_ladj = _auto_with_logabsdet_jacobian (f, x, getjacobian)
57
+ ref_y, ref_ladj = _auto_with_logabsdet_jacobian (f, x, getjacobian, rv_and_back )
74
58
@test compare (y, ref_y; kwargs... )
75
59
@test compare (ladj, ref_ladj; kwargs... )
76
60
end
0 commit comments