Skip to content

Commit 3c8c393

Browse files
committed
Move ChainRulesCore support to extension
1 parent 0418307 commit 3c8c393

File tree

7 files changed

+44
-21
lines changed

7 files changed

+44
-21
lines changed

Project.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ authors = ["Chad Scherrer <[email protected]>", "Oliver Schulz <oschulz@mp
44
version = "0.14.11"
55

66
[deps]
7-
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
87
ChangesOfVariables = "9e997f8a-9a97-42d5-a9f1-ce6bfc15e2c0"
98
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
109
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
@@ -29,6 +28,12 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2928
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
3029
Tricks = "410a4b4d-49e4-4fbc-ab6d-cb71b17b3775"
3130

31+
[weakdeps]
32+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
33+
34+
[extensions]
35+
MeasureBaseChainRulesCoreExt = "ChainRulesCore"
36+
3237
[compat]
3338
ChainRulesCore = "1"
3439
ChangesOfVariables = "0.1.3"

ext/MeasureBaseChainRulesCoreExt.jl

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
# This file is a part of MeasureBase.jl, licensed under the MIT License (MIT).
2+
3+
module MeasureBaseChainRulesCoreExt
4+
5+
using MeasureBase
6+
using ChainRulesCore: NoTangent, ZeroTangent
7+
import ChainRulesCore
8+
9+
10+
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
11+
y = _checksupport(cond, result)
12+
function _checksupport_pullback(ȳ)
13+
return NoTangent(), ZeroTangent(), one(ȳ)
14+
end
15+
y, _checksupport_pullback
16+
end
17+
18+
19+
_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
20+
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
21+
return require_insupport(μ, x), _require_insupport_pullback
22+
end
23+
24+
25+
_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
26+
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback
27+
28+
29+
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
30+
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback
31+
32+
33+
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
34+
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback
35+
36+
37+
end # module MeasureBaseChainRulesCoreExt

src/MeasureBase.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ using IntervalSets
3131
using PrettyPrinting
3232
const Pretty = PrettyPrinting
3333

34-
using ChainRulesCore
3534
import FillArrays
3635
using Static
3736
using Static: StaticInteger

src/density-core.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,6 @@ end
3535

3636
_checksupport(cond, result) = ifelse(cond == true, result, oftype(result, -Inf))
3737

38-
import ChainRulesCore
39-
@inline function ChainRulesCore.rrule(::typeof(_checksupport), cond, result)
40-
y = _checksupport(cond, result)
41-
function _checksupport_pullback(ȳ)
42-
return NoTangent(), ZeroTangent(), one(ȳ)
43-
end
44-
y, _checksupport_pullback
45-
end
4638

4739
export unsafe_logdensityof
4840

src/getdof.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ end
5151
_check_dof_pullback(ΔΩ) = NoTangent(), NoTangent(), NoTangent()
5252
ChainRulesCore.rrule(::typeof(check_dof), ν, μ) = check_dof(ν, μ), _check_dof_pullback
5353

54+
5455
"""
5556
MeasureBase.NoArgCheck{MU,T}
5657
@@ -78,6 +79,3 @@ end
7879
@propagate_inbounds function checked_arg(mu::MU, x) where {MU}
7980
_default_checked_arg(MU, basemeasure(mu), x)
8081
end
81-
82-
_checked_arg_pullback(ΔΩ) = NoTangent(), NoTangent(), ΔΩ
83-
ChainRulesCore.rrule(::typeof(checked_arg), ν, x) = checked_arg(ν, x), _checked_arg_pullback

src/insupport.jl

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,6 @@ Checks if `x` is in the support of distribution/measure `μ`, throws an
1818
"""
1919
function require_insupport end
2020

21-
_require_insupport_pullback(ΔΩ) = NoTangent(), ZeroTangent()
22-
function ChainRulesCore.rrule(::typeof(require_insupport), μ, x)
23-
return require_insupport(μ, x), _require_insupport_pullback
24-
end
25-
2621
function require_insupport(μ, x)
2722
if !insupport(μ, x)
2823
throw(ArgumentError("x is not within the support of μ"))

src/transport.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -135,9 +135,6 @@ end
135135
return static(10)
136136
end
137137

138-
_origin_depth_pullback(ΔΩ) = NoTangent(), NoTangent()
139-
ChainRulesCore.rrule(::typeof(_origin_depth), ν) = _origin_depth(ν), _origin_depth_pullback
140-
141138
# If both both measures have no origin:
142139
function _transport_between_origins(ν, ::StaticInteger{0}, ::StaticInteger{0}, μ, x)
143140
_transport_with_intermediate(ν, _transport_intermediate(ν, μ), μ, x)

0 commit comments

Comments
 (0)