Skip to content

Commit 7d252ad

Browse files
authored
feat: support nested tuples of arrays as Caches (#748)
* fix: separe `prepare` from the hidden `prepare_nokwarg` * DOcs * Typing * Fix * Toggle fail fast * feat: recursive similar for caches * Recursive caches * Enzyme * Remove new tests * SCT fix * Nesting in test scens * More sophisticated testing * Fix * Coverage
1 parent e1d171f commit 7d252ad

File tree

25 files changed

+115
-52
lines changed

25 files changed

+115
-52
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.45"
4+
version = "0.6.46"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceEnzymeExt/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ force_annotation(f::F) where {F} = Const(f)
5454
end
5555

5656
@inline function _translate(
57-
backend::AutoEnzyme, ::Mode, ::Val{B}, c::DI.GeneralizedCache
57+
backend::AutoEnzyme, ::Mode, ::Val{B}, c::Union{DI.Cache,DI.PrepContext}
5858
) where {B}
5959
if B == 1
6060
return Duplicated(DI.unwrap(c), make_zero(DI.unwrap(c)))

DifferentiationInterface/ext/DifferentiationInterfaceForwardDiffExt/utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ function _translate(
8989
end
9090
function _translate(::Type{D}, c::DI.Cache) where {D<:Dual}
9191
c0 = DI.unwrap(c)
92-
return similar(c0, D)
92+
return DI.recursive_similar(c0, D)
9393
end
9494

9595
function translate(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}
@@ -106,7 +106,7 @@ function _translate_toprep(
106106
end
107107
function _translate_toprep(::Type{D}, c::DI.Cache) where {D<:Dual}
108108
c0 = DI.unwrap(c)
109-
return similar(c0, D)
109+
return DI.recursive_similar(c0, D)
110110
end
111111

112112
function translate_toprep(::Type{D}, contexts::NTuple{C,DI.Context}) where {D<:Dual,C}

DifferentiationInterface/ext/DifferentiationInterfaceSparseConnectivityTracerExt/DifferentiationInterfaceSparseConnectivityTracerExt.jl

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,26 +5,23 @@ import DifferentiationInterface as DI
55
using SparseConnectivityTracer:
66
TracerSparsityDetector, TracerLocalSparsityDetector, jacobian_buffer, hessian_buffer
77

8-
@inline _jacobian_translate(detector, c::DI.Constant) = DI.unwrap(c)
9-
@inline function _jacobian_translate(detector, c::DI.Cache{<:AbstractArray})
10-
return jacobian_buffer(DI.unwrap(c), detector)
8+
@inline _translate(::Type, c::DI.Constant) = DI.unwrap(c)
9+
@inline function _translate(::Type{T}, c::DI.Cache) where {T}
10+
return DI.recursive_similar(DI.unwrap(c), T)
1111
end
1212

13-
function jacobian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
13+
function jacobian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
14+
T = eltype(jacobian_buffer(x, detector))
1415
new_contexts = map(contexts) do c
15-
_jacobian_translate(detector, c)
16+
_translate(T, c)
1617
end
1718
return new_contexts
1819
end
1920

20-
@inline _hessian_translate(detector, c::DI.Constant) = DI.unwrap(c)
21-
@inline function _hessian_translate(detector, c::DI.Cache{<:AbstractArray})
22-
return hessian_buffer(DI.unwrap(c), detector)
23-
end
24-
25-
function hessian_translate(detector, contexts::Vararg{DI.Context,C}) where {C}
21+
function hessian_translate(detector, x, contexts::Vararg{DI.Context,C}) where {C}
22+
T = eltype(hessian_buffer(x, detector))
2623
new_contexts = map(contexts) do c
27-
_hessian_translate(detector, c)
24+
_translate(T, c)
2825
end
2926
return new_contexts
3027
end
@@ -35,7 +32,7 @@ function DI.jacobian_sparsity_with_contexts(
3532
x,
3633
contexts::Vararg{DI.Context,C},
3734
) where {F,C}
38-
contexts_tracer = jacobian_translate(detector, contexts...)
35+
contexts_tracer = jacobian_translate(detector, x, contexts...)
3936
fc = DI.FixTail(f, contexts_tracer...)
4037
return jacobian_sparsity(fc, x, detector)
4138
end
@@ -47,7 +44,7 @@ function DI.jacobian_sparsity_with_contexts(
4744
x,
4845
contexts::Vararg{DI.Context,C},
4946
) where {F,C}
50-
contexts_tracer = jacobian_translate(detector, contexts...)
47+
contexts_tracer = jacobian_translate(detector, x, contexts...)
5148
fc! = DI.FixTail(f!, contexts_tracer...)
5249
return jacobian_sparsity(fc!, y, x, detector)
5350
end
@@ -58,7 +55,7 @@ function DI.hessian_sparsity_with_contexts(
5855
x,
5956
contexts::Vararg{DI.Context,C},
6057
) where {F,C}
61-
contexts_tracer = hessian_translate(detector, contexts...)
58+
contexts_tracer = hessian_translate(detector, x, contexts...)
6259
fc = DI.FixTail(f, contexts_tracer...)
6360
return hessian_sparsity(fc, x, detector)
6461
end

DifferentiationInterface/ext/DifferentiationInterfaceZygoteExt/DifferentiationInterfaceZygoteExt.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,10 @@ DI.check_available(::AutoZygote) = true
1717
DI.inplace_support(::AutoZygote) = DI.InPlaceNotSupported()
1818

1919
translate(c::DI.Context) = DI.unwrap(c)
20-
translate(c::DI.Cache) = Buffer(DI.unwrap(c))
20+
translate(c::DI.Cache{<:AbstractArray}) = Buffer(DI.unwrap(c))
21+
function translate(c::DI.Cache{<:Union{Tuple,NamedTuple}})
22+
return map(translate, map(DI.Cache, DI.unwrap(c)))
23+
end
2124

2225
## Pullback
2326

DifferentiationInterface/src/utils/context.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ Abstract supertype for additional context arguments, which can be passed to diff
2323
abstract type Context end
2424

2525
abstract type GeneralizedConstant <: Context end
26-
abstract type GeneralizedCache <: Context end
2726

2827
unwrap(c::Context) = c.data
2928
Base.:(==)(c1::Context, c2::Context) = unwrap(c1) == unwrap(c2)
@@ -78,7 +77,7 @@ The initial values present inside the cache do not matter.
7877
For some backends, preparation allocates the required memory for `Cache` contexts with the right element type, similar to [PreallocationTools.jl](https://github.com/SciML/PreallocationTools.jl).
7978
8079
!!! warning
81-
Most backends require any `Cache` context to be an `AbstractArray`.
80+
Some backends require any `Cache` context to be an `AbstractArray`, others accept nested (named) tuples of `AbstractArray`s.
8281
8382
# Example
8483
@@ -97,7 +96,7 @@ julia> gradient(f, prep, AutoForwardDiff(), [3.0, 4.0], Cache(zeros(2)))
9796
1.0
9897
````
9998
"""
100-
struct Cache{T} <: GeneralizedCache
99+
struct Cache{T} <: Context
101100
data::T
102101
end
103102

@@ -114,12 +113,10 @@ struct BackendContext{T} <: GeneralizedConstant
114113
data::T
115114
end
116115

117-
struct PrepContext{T} <: GeneralizedCache
116+
struct PrepContext{T} <: Context
118117
data::T
119118
end
120119

121-
struct UnknownContext <: Context end
122-
123120
## Context manipulation
124121

125122
struct Rewrap{C,T}
@@ -146,4 +143,4 @@ function with_contexts(f::F, contexts::Vararg{Context,N}) where {F,N}
146143
end
147144

148145
adapt_eltype(c::Constant, ::Type) = c
149-
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(similar(unwrap(c), T))
146+
adapt_eltype(c::Cache, ::Type{T}) where {T} = Cache(recursive_similar(unwrap(c), T))

DifferentiationInterface/src/utils/linalg.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,15 @@ At the moment, this only returns `false` for `StaticArrays.SArray`.
1010
"""
1111
ismutable_array(::Type) = true
1212
ismutable_array(x) = ismutable_array(typeof(x))
13+
14+
"""
15+
recursive_similar(x, T)
16+
17+
Apply `similar(_, T)` recursively to `x` or its components.
18+
19+
Works if `x` is an `AbstractArray` or a (nested) `NTuple` / `NamedTuple` of `AbstractArray`s.
20+
"""
21+
recursive_similar(x::AbstractArray, ::Type{T}) where {T} = similar(x, T)
22+
function recursive_similar(x::Union{Tuple,NamedTuple}, ::Type{T}) where {T}
23+
return map(xi -> recursive_similar(xi, T), x)
24+
end

DifferentiationInterface/test/Back/Enzyme/test.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ end;
5555

5656
test_differentiation(
5757
backends[2],
58-
default_scenarios(; include_normal=false, include_cachified=true);
58+
default_scenarios(; include_normal=false, include_cachified=true, use_tuples=true);
5959
excluded=SECOND_ORDER,
6060
logging=LOGGING,
6161
)

DifferentiationInterface/test/Back/FiniteDiff/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ end
2222
@testset "Dense" begin
2323
test_differentiation(
2424
AutoFiniteDiff(),
25-
default_scenarios(; include_constantified=true, include_cachified=true);
25+
default_scenarios(;
26+
include_constantified=true, include_cachified=true, use_tuples=true
27+
);
2628
excluded=[:second_derivative, :hvp],
2729
logging=LOGGING,
2830
)

DifferentiationInterface/test/Back/FiniteDifferences/test.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,9 @@ end
1919

2020
test_differentiation(
2121
AutoFiniteDifferences(; fdm=FiniteDifferences.central_fdm(3, 1)),
22-
default_scenarios(; include_constantified=true, include_cachified=true);
22+
default_scenarios(;
23+
include_constantified=true, include_cachified=true, use_tuples=true
24+
);
2325
excluded=SECOND_ORDER,
2426
logging=LOGGING,
2527
);

0 commit comments

Comments
 (0)