|
81 | 81 |
|
82 | 82 | _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj::Tuple{Any,Real}) = y_with_ladj
|
83 | 83 |
|
84 |
| -function _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj) |
85 |
| - y = map_or_bc(_get_y, y_with_ladj) |
86 |
| - ladj = sum(map_or_bc(_get_ladj, y_with_ladj)) |
| 84 | +function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}} |
| 85 | + y = map_or_bc(first, y_with_ladj) |
| 86 | + ladj = sum(map_or_bc(last, y_with_ladj)) |
87 | 87 | #ladj = sum(_get_ladj, y_with_ladj)
|
88 | 88 | (y, ladj)
|
89 | 89 | end
|
90 | 90 |
|
91 | 91 | function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
|
92 |
| - ys, ladj = ChainRulesCore.unthunk(thunked_ΔΩ) |
93 |
| - NoTangent(), NoTangent(), broadcast(x -> (x, ladj), ys) |
| 92 | + ys, ladj = unthunk(thunked_ΔΩ) |
| 93 | + return NoTangent(), NoTangent(), tuple.(ys, ladj) |
94 | 94 | end
|
95 | 95 |
|
96 |
| -function ChainRulesCore.rrule(::typeof(ChangesOfVariables._with_ladj_on_mapped), map_or_bc::Function, y_with_ladj) |
97 |
| - return ChangesOfVariables._with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback |
| 96 | +function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::Function, y_with_ladj) |
| 97 | + return _with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback |
98 | 98 | end
|
99 | 99 |
|
100 | 100 | function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
|
|
0 commit comments