Skip to content

Commit 59ba92c

Browse files
authored
Merge pull request #153 from JuliaDiff/ox/tangentforward
Fix getindex on CompositeBundle of a struct
2 parents caad122 + b845365 commit 59ba92c

File tree

2 files changed

+21
-4
lines changed

2 files changed

+21
-4
lines changed

src/tangent.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -296,12 +296,9 @@ CompositeBundle{N, B}(tup::T) where {N, B, T} = CompositeBundle{N, B, T}(tup)
296296

297297
function Base.getindex(tb::CompositeBundle{N, B} where N, tti::TaylorTangentIndex) where {B}
298298
B <: SArray && error()
299-
Tangent{B}(map(tb.tup) do el
300-
el[tti]
301-
end...)
299+
return partial(tb, tti.i)
302300
end
303301

304-
305302
primal(b::CompositeBundle{N, <:Tuple} where N) = map(primal, b.tup)
306303
function primal(b::CompositeBundle{N, T} where N) where T<:CompositeBundle
307304
T(map(primal, b.tup)...)

test/tangent.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
module tagent
22
using Diffractor
33
using Diffractor: AbstractZeroBundle, ZeroBundle, DNEBundle
4+
using Diffractor: TaylorBundle, TaylorTangentIndex, CompositeBundle
5+
using ChainRulesCore
46
using Test
57

68
@testset "AbstractZeroBundle" begin
@@ -24,4 +26,22 @@ using Test
2426
end
2527
end
2628

29+
@testset "AD through constructor" begin
30+
#https://github.com/JuliaDiff/Diffractor.jl/issues/152
31+
# hits `getindex(::CompositeBundle{Foo152}, ::TaylorTangentIndex)`
32+
struct Foo152
33+
x::Float64
34+
end
35+
36+
# Unit Test
37+
cb = CompositeBundle{1, Foo152}((TaylorBundle{1, Float64}(23.5, (1.0,)),))
38+
tti = TaylorTangentIndex(1,)
39+
@test cb[tti] == Tangent{Foo152}(; x=1.0)
40+
41+
# Integration Test
42+
var"'" = Diffractor.PrimeDerivativeFwd
43+
f(x) = Foo152(x)
44+
@test f'(23.5) == Tangent{Foo152}(; x=1.0)
45+
end
46+
2747
end # module

0 commit comments

Comments
 (0)