Skip to content

Commit f09a1df

Browse files
committed
Support Broadcast.BroadcastFunction
1 parent dd295d6 commit f09a1df

File tree

3 files changed

+22
-4
lines changed

3 files changed

+22
-4
lines changed

docs/src/index.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,6 @@ ChangesOfVariables
66

77
This package defines the function [`with_logabsdet_jacobian`](@ref). `(y, ladj) = with_logabsdet_jacobian(f, x)` computes both the transformed value of `x` under the transformation `f` and the logarithm of the [volume element](https://en.wikipedia.org/wiki/Volume_element).
88

9-
`with_logabsdet_jacobian` supports mapped/broadcasted functions (via `Base.Fix1`) and (on Julia >=v1.6) function composition.
9+
`with_logabsdet_jacobian` supports mapped/broadcasted functions (via `Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and (on Julia >=v1.6) function composition.
1010

1111
Implementations of `with_logabsdet_jacobian(f)` for `identity`, `inv`, `adjoint` and `transpose` as well as for `exp`, `log`, `exp2`, `log2`, `exp10`, `log10`, `expm1` and `log1p` are included.

src/with_ladj.jl

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ For `(y, ladj) = with_logabsdet_jacobian(f, x)`, the following must hold true:
1212
* `ladj` is the `log(abs(det(jacobian(f, x))))`
1313
1414
`with_logabsdet_jacobian` comes with support for broadcasted/mapped functions
15-
(via `Base.Fix1`) and (Julia >=v1.6 only) `ComposedFunction`.
15+
(via `Base.Broadcast.BroadcastFunction` or `Base.Fix1`) and (Julia >=v1.6 only)
16+
`ComposedFunction`.
1617
1718
If no volume element is defined/applicable, `with_logabsdet_jacobian(f::F, x::T)`
1819
returns [`NoLogAbsDetJacobian{F,T}()`](@ref).
@@ -43,7 +44,11 @@ true
4344
4445
```jldoctest a
4546
X = rand(10)
46-
broadcasted_foo = Base.Fix1(broadcast, foo)
47+
broadcasted_foo = if VERSION >= v"1.6"
48+
Base.Broadcast.BroadcastFunction(foo)
49+
else
50+
Base.Fix1(broadcast, foo)
51+
end
4752
Y, ladj_Y = with_logabsdet_jacobian(broadcasted_foo, X)
4853
Y == broadcasted_foo(X) && ladj_Y ≈ logabsdet(ForwardDiff.jacobian(broadcasted_foo, X))[1]
4954
@@ -117,6 +122,13 @@ function _with_ladj_on_mapped(map_or_bc::F, y_with_ladj) where {F<:Union{typeof(
117122
(y, ladj)
118123
end
119124

125+
@static if VERSION >= v"1.6"
126+
function with_logabsdet_jacobian(mapped_f::Base.Broadcast.BroadcastFunction, X)
127+
f = mapped_f.f
128+
y_with_ladj = broadcast(Base.Fix1(with_logabsdet_jacobian, f), X)
129+
_with_ladj_on_mapped(broadcast, y_with_ladj)
130+
end
131+
end
120132

121133
function with_logabsdet_jacobian(mapped_f::Base.Fix1{<:Union{typeof(map),typeof(broadcast)}}, X)
122134
map_or_bc = mapped_f.f

test/test_with_ladj.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,12 @@ include("getjacobian.jl")
1212

1313

1414
@testset "with_logabsdet_jacobian" begin
15+
@static if VERSION >= v"1.6"
16+
_bc_func(f) = Base.Broadcast.BroadcastFunction(f)
17+
else
18+
_bc_func(f) = Base.Fix1(broadcast, f)
19+
end
20+
1521
@test with_logabsdet_jacobian(sum, rand(5)) == NoLogAbsDetJacobian{typeof(sum),Vector{Float64}}()
1622
@test with_logabsdet_jacobian(sum log, 5.0f0) == NoLogAbsDetJacobian{typeof(sum ∘ log),Float32}()
1723
@test with_logabsdet_jacobian(log sum, 5.0f0) == NoLogAbsDetJacobian{typeof(log ∘ sum),Float32}()
@@ -39,7 +45,7 @@ include("getjacobian.jl")
3945
end
4046

4147
@testset "with_logabsdet_jacobian on mapped and broadcasted" begin
42-
for f in (Base.Fix1(map, foo), Base.Fix1(broadcast, foo))
48+
for f in (_bc_func(foo), Base.Fix1(map, foo), Base.Fix1(broadcast, foo))
4349
for arg in (x, fill(x,), Ref(x), (x,), X)
4450
test_with_logabsdet_jacobian(f, arg, getjacobian, compare = isaprx)
4551
end

0 commit comments

Comments
 (0)