Skip to content

Commit dd295d6

Browse files
committed
Remove custom rrule for _with_ladj_on_mapped
Tricky to implement a correct rrule here that handles tangents which contain NoTangent or ZeroTangent.
1 parent 1c265b5 commit dd295d6

File tree

6 files changed

+8
-44
lines changed

6 files changed

+8
-44
lines changed

Project.toml

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,25 +3,15 @@ uuid = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
33
version = "0.1.6"
44

55
[deps]
6-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
76
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
87
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
98

10-
[weakdeps]
11-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
12-
13-
[extensions]
14-
ChangesOfVariablesChainRulesCoreExt = "ChainRulesCore"
15-
169
[compat]
17-
ChainRulesCore = "1"
1810
julia = "1"
1911

2012
[extras]
21-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
22-
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
2313
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
2414
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
2515

2616
[targets]
27-
test = ["ChainRulesCore", "ChainRulesTestUtils", "Documenter", "ForwardDiff"]
17+
test = ["Documenter", "ForwardDiff"]

README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ changes for functions that perform a change of variables (like coordinate
1212
transformations).
1313

1414
`ChangesOfVariables` is a very lightweight package and has no dependencies
15-
beyond `Base`, `LinearAlgebra`, `Test` and `ChainRulesCore`.
15+
beyond `Base`, `LinearAlgebra`, `Test`.
1616

1717
## Documentation
1818

ext/ChangesOfVariablesChainRulesCoreExt.jl

Lines changed: 0 additions & 20 deletions
This file was deleted.

src/ChangesOfVariables.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,5 @@ using Test
1414

1515
include("with_ladj.jl")
1616
include("test.jl")
17-
if !isdefined(Base, :get_extension)
18-
include("../ext/ChangesOfVariablesChainRulesCoreExt.jl")
19-
end
2017

2118
end # module

src/with_ladj.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj::Tuple{Any,Real}) where
107107
return y_with_ladj
108108
end
109109

110+
_get_all_first(x) = map(first, x)
111+
# Use x -> x[2] instead of last, using last causes horrible performance in Zygote here:
112+
_sum_over_second(x) = sum(x -> x[2], x)
113+
110114
function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(map),typeof(broadcast)}}
111-
y = map_or_bc(first, y_with_ladj)
112-
ladj = sum(last, y_with_ladj)
115+
y = _get_all_first(y_with_ladj)
116+
ladj = _sum_over_second(y_with_ladj)
113117
(y, ladj)
114118
end
115119

test/test_with_ladj.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@ using LinearAlgebra
77

88
using ChangesOfVariables
99
using ChangesOfVariables: test_with_logabsdet_jacobian
10-
using ChainRulesTestUtils
1110

1211
include("getjacobian.jl")
1312

@@ -66,10 +65,4 @@ include("getjacobian.jl")
6665
test_with_logabsdet_jacobian(f, x, getjacobian)
6766
end
6867
end
69-
70-
@testset "rrules" begin
71-
for map_or_bc in (map, broadcast)
72-
test_rrule(ChangesOfVariables._with_ladj_on_mapped, map_or_bc, [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)])
73-
end
74-
end
7568
end

0 commit comments

Comments
 (0)