@@ -334,10 +334,6 @@ function unbundle(atb::TaylorBundle{Order, A}) where {Order, Dim, T, A<:Abstract
334
334
StructArray {TaylorBundle{Order, T}} ((atb. primal, atb. tangent. coeffs... ))
335
335
end
336
336
337
- function ChainRulesCore. rrule (:: typeof (unbundle), atb:: TaylorBundle )
338
- unbundle (atb), Δ-> throw (Δ)
339
- end
340
-
341
337
function StructArrays. staticschema (:: Type{<:TaylorBundle{N, B, T}} ) where {N, B, T}
342
338
Tuple{B, T. parameters... }
343
339
end
@@ -355,11 +351,11 @@ function StructArrays.createinstance(T::Type{<:TaylorBundle}, args...)
355
351
T (first (args), Base. tail (args))
356
352
end
357
353
358
- function unbundle (zb :: ZeroBundle {N, A} ) where {N,T,Dim,A<: AbstractArray{T, Dim} }
359
- StructArray {ZeroBundle {N, T}} ((zb . primal, fill (zb . tangent. val, size (zb . primal)... )))
354
+ function unbundle (u :: UniformBundle {N, A} ) where {N,T,Dim,A<: AbstractArray{T, Dim} }
355
+ StructArray {UniformBundle {N, T}} ((u . primal, fill (u . tangent. val, size (u . primal)... )))
360
356
end
361
357
362
- function ChainRulesCore. rrule (:: typeof (unbundle), atb:: ZeroBundle )
358
+ function ChainRulesCore. rrule (:: typeof (unbundle), atb:: AbstractTangentBundle )
363
359
unbundle (atb), Δ-> throw (Δ)
364
360
end
365
361
@@ -383,6 +379,11 @@ function rebundle(A::AbstractArray{<:TaylorBundle{N}}) where {N}
383
379
end )
384
380
end
385
381
382
+ function rebundle (A:: AbstractArray{<:UniformBundle{N}} ) where {N}
383
+ @assert all (x-> getfield (x, :tangent )== (first (A). tangent), A)
384
+ UniformBundle {N} (map (x-> x. primal, A), first (A). tangent. val)
385
+ end
386
+
386
387
function ChainRulesCore. rrule (:: typeof (rebundle), atb)
387
388
rebundle (atb), Δ-> throw (Δ)
388
389
end
0 commit comments