Skip to content

Commit d819f29

Browse files
committed
Use Tangent in _with_ladj_on_mapped pullback
1 parent 534a4f2 commit d819f29

File tree

3 files changed

+14
-10
lines changed

3 files changed

+14
-10
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,9 @@ ChainRulesCore = "1"
1212
julia = "1"
1313

1414
[extras]
15+
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
1516
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
1617
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1718

1819
[targets]
19-
test = ["Documenter", "ForwardDiff"]
20+
test = ["ChainRulesTestUtils", "Documenter", "ForwardDiff"]

src/with_ladj.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -86,13 +86,18 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(
8686
(y, ladj)
8787
end
8888

89-
function _with_ladj_on_mapped_pullback(thunked_ΔΩ)
89+
90+
# Need to use a type for this, type inference fails when using a pullback
91+
# closure over YLT in the rrule, resulting in bad performance:
92+
struct WithLadjOnMappedPullback{YLT} <: Function end
93+
function (::WithLadjOnMappedPullback{YLT})(thunked_ΔΩ) where YLT
9094
ys, ladj = unthunk(thunked_ΔΩ)
91-
return NoTangent(), NoTangent(), tuple.(ys, ladj)
95+
return NoTangent(), NoTangent(), broadcast((y, l) -> Tangent{YLT}(y, l), ys, ladj)
9296
end
9397

9498
function ChainRulesCore.rrule(::typeof(_with_ladj_on_mapped), map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
95-
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), _with_ladj_on_mapped_pullback
99+
YLT = eltype(y_with_ladj)
100+
return _with_ladj_on_mapped(map_or_bc, y_with_ladj), WithLadjOnMappedPullback{YLT}()
96101
end
97102

98103
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)

test/test_with_ladj.jl

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ using Test
55

66
using LinearAlgebra
77

8-
using ChangesOfVariables: test_with_logabsdet_jacobian, _with_ladj_on_mapped
9-
using ChainRulesCore
8+
using ChangesOfVariables
9+
using ChangesOfVariables: test_with_logabsdet_jacobian
10+
using ChainRulesTestUtils
1011

1112
include("getjacobian.jl")
1213

@@ -63,10 +64,7 @@ include("getjacobian.jl")
6364

6465
@testset "rrules" begin
6566
for map_or_bc in (map, broadcast)
66-
x = [(1, 2), (3, 4), (5, 6)]
67-
y, back = rrule(_with_ladj_on_mapped, map_or_bc, x)
68-
@test y == ([1, 3, 5], 12) == _with_ladj_on_mapped(map_or_bc, x)
69-
@test back(@thunk ([7, 8, 9], 12)) == (ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(), [(7, 12), (8, 12), (9, 12)])
67+
test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
7068
end
7169
end
7270
end

0 commit comments

Comments
 (0)