|
1 |
| -using MonotonicSplines |
2 | 1 | # a new implementation of Neural Spline Flow based on MonotonicSplines.jl
|
3 | 2 | # the construction of the RQS seems to be more efficient than the one in Bijectors.jl
|
4 | 3 | # and supports batched operations.
|
5 | 4 |
|
6 | 5 | """
|
7 |
| -Neural Rational quadratic Spline layer |
| 6 | +Neural Rational Quadratic Spline Coupling layer |
8 | 7 | # References
|
9 | 8 | [1] Durkan, C., Bekasov, A., Murray, I., & Papamakarios, G., Neural Spline Flows, CoRR, arXiv:1906.04032 [stat.ML], (2019).
|
10 | 9 | """
|
|
41 | 40 |
|
42 | 41 | function get_nsc_params(nsc::NeuralSplineCoupling, x::AbstractVecOrMat)
|
43 | 42 | nnoutput = nsc.nn(x)
|
44 |
| - px, py, dydx = MonotonicSplines.rqs_params_from_nn(nnoutput, nsc.n_dims_transferred, nsc.B) |
| 43 | + px, py, dydx = MonotonicSplines.rqs_params_from_nn( |
| 44 | + nnoutput, nsc.n_dims_transferred, nsc.B |
| 45 | + ) |
45 | 46 | return px, py, dydx
|
46 | 47 | end
|
47 | 48 |
|
48 | 49 | # when input x is a vector instead of a matrix
|
49 | 50 | # need this to transform it to a matrix with one row
|
50 | 51 | # otherwise, rqs_forward and rqs_inverse will throw an error
|
51 |
| -_ensure_matrix(x) = x isa AbstractVector ? reshape(x, 1, length(x)) : x |
| 52 | +_ensure_matrix(x) = x isa AbstractVector ? reshape(x, length(x), 1) : x |
52 | 53 |
|
53 |
| -function Bijectors.transform(nsc::NeuralSplineCoupling, x::AbstractVecOrMat) |
| 54 | +function Bijectors.transform(nsc::NeuralSplineCoupling, x::AbstractVector) |
54 | 55 | x1, x2, x3 = Bijectors.partition(nsc.mask, x)
|
55 | 56 | # instantiate rqs knots and derivatives
|
56 | 57 | px, py, dydx = get_nsc_params(nsc, x2)
|
57 | 58 | x1 = _ensure_matrix(x1)
|
58 | 59 | y1, _ = MonotonicSplines.rqs_forward(x1, px, py, dydx)
|
| 60 | + return Bijectors.combine(nsc.mask, vec(y1), x2, x3) |
| 61 | +end |
| 62 | +function Bijectors.transform(nsc::NeuralSplineCoupling, x::AbstractMatrix) |
| 63 | + x1, x2, x3 = Bijectors.partition(nsc.mask, x) |
| 64 | + # instantiate rqs knots and derivatives |
| 65 | + px, py, dydx = get_nsc_params(nsc, x2) |
| 66 | + y1, _ = MonotonicSplines.rqs_forward(x1, px, py, dydx) |
59 | 67 | return Bijectors.combine(nsc.mask, y1, x2, x3)
|
60 | 68 | end
|
61 | 69 |
|
62 |
| -function Bijectors.with_logabsdet_jacobian(nsc::NeuralSplineCoupling, x::AbstractVecOrMat) |
| 70 | +function Bijectors.with_logabsdet_jacobian(nsc::NeuralSplineCoupling, x::AbstractVector) |
63 | 71 | x1, x2, x3 = Bijectors.partition(nsc.mask, x)
|
64 | 72 | # instantiate rqs knots and derivatives
|
65 | 73 | px, py, dydx = get_nsc_params(nsc, x2)
|
66 | 74 | x1 = _ensure_matrix(x1)
|
67 | 75 | y1, logjac = MonotonicSplines.rqs_forward(x1, px, py, dydx)
|
68 |
| - return Bijectors.combine(nsc.mask, y1, x2, x3), logjac isa Real ? logjac : vec(logjac) |
| 76 | + return Bijectors.combine(nsc.mask, vec(y1), x2, x3), logjac[1] |
| 77 | +end |
| 78 | +function Bijectors.with_logabsdet_jacobian(nsc::NeuralSplineCoupling, x::AbstractMatrix) |
| 79 | + x1, x2, x3 = Bijectors.partition(nsc.mask, x) |
| 80 | + # instantiate rqs knots and derivatives |
| 81 | + px, py, dydx = get_nsc_params(nsc, x2) |
| 82 | + y1, logjac = MonotonicSplines.rqs_forward(x1, px, py, dydx) |
| 83 | + return Bijectors.combine(nsc.mask, y1, x2, x3), vec(logjac) |
69 | 84 | end
|
70 | 85 |
|
71 |
| -function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVecOrMat) |
| 86 | +function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector) |
72 | 87 | nsc = insl.orig
|
73 | 88 | y1, y2, y3 = partition(nsc.mask, y)
|
74 | 89 | px, py, dydx = get_nsc_params(nsc, y2)
|
75 | 90 | y1 = _ensure_matrix(y1)
|
76 | 91 | x1, _ = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
|
| 92 | + return Bijectors.combine(nsc.mask, vec(x1), y2, y3) |
| 93 | +end |
| 94 | +function Bijectors.transform(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractMatrix) |
| 95 | + nsc = insl.orig |
| 96 | + y1, y2, y3 = partition(nsc.mask, y) |
| 97 | + px, py, dydx = get_nsc_params(nsc, y2) |
| 98 | + x1, _ = MonotonicSplines.rqs_inverse(y1, px, py, dydx) |
77 | 99 | return Bijectors.combine(nsc.mask, x1, y2, y3)
|
78 | 100 | end
|
79 | 101 |
|
80 |
| -function Bijectors.with_logabsdet_jacobian(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVecOrMat) |
| 102 | +function Bijectors.with_logabsdet_jacobian(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractVector) |
81 | 103 | nsc = insl.orig
|
82 | 104 | y1, y2, y3 = partition(nsc.mask, y)
|
83 | 105 | px, py, dydx = get_nsc_params(nsc, y2)
|
84 | 106 | y1 = _ensure_matrix(y1)
|
85 | 107 | x1, logjac = MonotonicSplines.rqs_inverse(y1, px, py, dydx)
|
86 |
| - return Bijectors.combine(nsc.mask, x1, y2, y3), logjac isa Real ? logjac : vec(logjac) |
| 108 | + return Bijectors.combine(nsc.mask, vec(x1), y2, y3), logjac[1] |
| 109 | +end |
| 110 | +function Bijectors.with_logabsdet_jacobian(insl::Inverse{<:NeuralSplineCoupling}, y::AbstractMatrix) |
| 111 | + nsc = insl.orig |
| 112 | + y1, y2, y3 = partition(nsc.mask, y) |
| 113 | + px, py, dydx = get_nsc_params(nsc, y2) |
| 114 | + x1, logjac = MonotonicSplines.rqs_inverse(y1, px, py, dydx) |
| 115 | + return Bijectors.combine(nsc.mask, x1, y2, y3), vec(logjac) |
87 | 116 | end
|
88 | 117 |
|
89 | 118 | function (nsc::NeuralSplineCoupling)(x::AbstractVecOrMat)
|
|
0 commit comments