Skip to content

Commit 5ef880b

Browse files
committed
Add custom rrule for _with_ladj_on_mapped
Speeds up AD of with_logabsdet_jacobian with mapped/broadcasted functions significantly.
1 parent 9d2a665 commit 5ef880b

File tree

4 files changed

+24
-1
lines changed

4 files changed

+24
-1
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,12 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
33
version = "0.1.1"
44

55
[deps]
6+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
67
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
78
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
89

910
[compat]
11+
ChainRulesCore = "1"
1012
julia = "1"
1113

1214
[extras]

src/ChangesOfVariables.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ transformations).
99
"""
1010
module ChangesOfVariables
1111

12+
using ChainRulesCore
1213
using LinearAlgebra
1314
using Test
1415

src/with_ladj.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,9 +84,19 @@ _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj::Tuple{Any,Real}) = y_with
8484
function _with_ladj_on_mapped(map_or_bc::Function, y_with_ladj)
8585
y = map_or_bc(_get_y, y_with_ladj)
8686
ladj = sum(map_or_bc(_get_ladj, y_with_ladj))
87+
#ladj = sum(_get_ladj, y_with_ladj)
8788
(y, ladj)
8889
end
8990

91+
function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
92+
ys, ladj = ChainRulesCore.unthunk(thunked_ΔΩ)
93+
NoTangent(), NoTangent(), broadcast(x -> (x, ladj), ys)
94+
end
95+
96+
function ChainRulesCore.rrule(::typeof(ChangesOfVariables._with_ladj_on_mapped), map_or_bc::Function, y_with_ladj)
97+
return ChangesOfVariables._with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
98+
end
99+
90100
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
91101
map_or_bc = mapped_f.f
92102
f = mapped_f.x

test/test_with_ladj.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ using Test
55

66
using LinearAlgebra
77

8-
using ChangesOfVariables: test_with_logabsdet_jacobian
8+
using ChangesOfVariables: test_with_logabsdet_jacobian, _with_ladj_on_mapped
9+
using ChainRulesCore
910

1011
include("getjacobian.jl")
1112

@@ -59,4 +60,13 @@ include("getjacobian.jl")
5960
test_with_logabsdet_jacobian(f, x, getjacobian)
6061
end
6162
end
63+
64+
@testset "rrules" begin
65+
for map_or_bc in (map, broadcast)
66+
x = [(1, 2), (3, 4), (5, 6)]
67+
y, back = rrule(_with_ladj_on_mapped, map_or_bc, x)
68+
@test y == ([1, 3, 5], 12) == _with_ladj_on_mapped(map_or_bc, x)
69+
@test back(@thunk ([7, 8, 9], 12)) == (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), [(7, 12), (8, 12), (9, 12)])
70+
end
71+
end
6272
end

0 commit comments

Comments
 (0)