Skip to content

Commit 51ac353

Browse files
authored
feat: make FromPrimitive wrappers public (#825)
1 parent dae09ef commit 51ac353

File tree

8 files changed

+61
-19
lines changed

8 files changed

+61
-19
lines changed

DifferentiationInterface/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
## [Unreleased]
99

10+
### Added
11+
12+
- Make `AutoForwardFromPrimitive` and `AutoReverseFromPrimitive` public ([#825])
13+
1014
### Fixed
1115

1216
- Replace `one` with `oneunit` in basis computation ([#826])
@@ -67,6 +71,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
6771
[0.6.53]: https://github.com/JuliaDiff/DifferentiationInterface.jl/compare/DifferentiationInterface-v0.6.52...DifferentiationInterface-v0.6.53
6872

6973
[#826]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/826
74+
[#825]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/825
7075
[#823]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/823
7176
[#818]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/818
7277
[#812]: https://github.com/JuliaDiff/DifferentiationInterface.jl/pull/812

DifferentiationInterface/docs/src/api.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,13 @@ MixedMode
132132
DenseSparsityDetector
133133
```
134134

135+
### From primitive
136+
137+
```@docs
138+
DifferentiationInterface.AutoForwardFromPrimitive
139+
DifferentiationInterface.AutoReverseFromPrimitive
140+
```
141+
135142
## Internals
136143

137144
The following is not part of the public API.

DifferentiationInterface/src/DifferentiationInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,7 @@ export AutoSparse
126126
## Public but not exported
127127

128128
@public inner, outer
129+
@public AutoForwardFromPrimitive, AutoReverseFromPrimitive
129130

130131
include("init.jl")
131132

DifferentiationInterface/src/first_order/mixed_mode.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,3 +41,10 @@ Appropriate mode type for `MixedMode` backends.
4141
"""
4242
struct ForwardAndReverseMode <: ADTypes.AbstractMode end
4343
ADTypes.mode(::MixedMode) = ForwardAndReverseMode()
44+
45+
function threshold_batchsize(backend::MixedMode, B::Integer)
46+
return MixedMode(
47+
threshold_batchsize(forward_backend(backend), B),
48+
threshold_batchsize(reverse_backend(backend), B),
49+
)
50+
end

DifferentiationInterface/src/misc/from_primitive.jl

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3,26 +3,45 @@ abstract type FromPrimitive{inplace} <: AbstractADType end
33
check_available(backend::FromPrimitive) = check_available(backend.backend)
44
inplace_support(::FromPrimitive{true}) = InPlaceSupported()
55
inplace_support(::FromPrimitive{false}) = InPlaceNotSupported()
6-
function inner_preparation_behavior(backend::FromPrimitive)
7-
return inner_preparation_behavior(backend.backend)
6+
7+
function pick_batchsize(backend::FromPrimitive, x_or_y::AbstractArray)
8+
return pick_batchsize(backend.backend, x_or_y)
89
end
910

1011
function pick_batchsize(backend::FromPrimitive, N::Integer)
1112
return pick_batchsize(backend.backend, N)
1213
end
1314

15+
function inner_preparation_behavior(backend::FromPrimitive)
16+
return inner_preparation_behavior(backend.backend)
17+
end
18+
19+
function overloaded_input(::typeof(pushforward), f, backend::FromPrimitive, x, tx::NTuple)
20+
return overloaded_input(pushforward, f, backend.backend, x, tx)
21+
end
22+
23+
function overloaded_input(
24+
::typeof(pushforward), f!, y, backend::FromPrimitive, x, tx::NTuple
25+
)
26+
return overloaded_input(pushforward, f!, y, backend.backend, x, tx)
27+
end
28+
1429
"""
15-
AutoForwardFromPrimitive
30+
AutoForwardFromPrimitive(backend::AbstractADType)
1631
17-
Wrapper which forces a given backend to act as a reverse-mode backend.
32+
Wrapper which forces a given backend to act as a forward-mode backend, using only its native `value_and_pushforward` primitive and re-implementing the rest from scratch.
1833
19-
Used in internal testing.
34+
!!! tip
35+
This can be useful to circumvent high-level operators when they have impractical limitations.
36+
For instance, ForwardDiff.jl's `jacobian` does not support GPU arrays but its `pushforward` does, so `AutoForwardFromPrimitive(AutoForwardDiff())` has a GPU-friendly `jacobian`.
2037
"""
2138
struct AutoForwardFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
2239
backend::B
2340
end
2441

25-
function AutoForwardFromPrimitive(backend::AbstractADType; inplace=true)
42+
function AutoForwardFromPrimitive(
43+
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
44+
)
2645
return AutoForwardFromPrimitive{inplace,typeof(backend)}(backend)
2746
end
2847

@@ -133,17 +152,17 @@ function value_and_pushforward!(
133152
end
134153

135154
"""
136-
AutoReverseFromPrimitive
137-
138-
Wrapper which forces a given backend to act as a reverse-mode backend.
155+
AutoReverseFromPrimitive(backend::AbstractADType)
139156
140-
Used in internal testing.
157+
Wrapper which forces a given backend to act as a reverse-mode backend, using only its native `value_and_pullback` implementation and rebuilding the rest from scratch.
141158
"""
142159
struct AutoReverseFromPrimitive{inplace,B<:AbstractADType} <: FromPrimitive{inplace}
143160
backend::B
144161
end
145162

146-
function AutoReverseFromPrimitive(backend::AbstractADType; inplace=true)
163+
function AutoReverseFromPrimitive(
164+
backend::AbstractADType; inplace::Bool=Bool(inplace_support(backend))
165+
)
147166
return AutoReverseFromPrimitive{inplace,typeof(backend)}(backend)
148167
end
149168

DifferentiationInterface/src/utils/batchsize.jl

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -112,13 +112,6 @@ function threshold_batchsize(backend::SecondOrder, B::Integer)
112112
)
113113
end
114114

115-
function threshold_batchsize(backend::MixedMode, B::Integer)
116-
return MixedMode(
117-
threshold_batchsize(forward_backend(backend), B),
118-
threshold_batchsize(reverse_backend(backend), B),
119-
)
120-
end
121-
122115
"""
123116
reasonable_batchsize(N::Integer, Bmax::Integer)
124117

DifferentiationInterface/test/Back/ForwardDiff/test.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import DifferentiationInterface as DI
88
import DifferentiationInterfaceTest as DIT
99
using ForwardDiff: ForwardDiff
1010
using StaticArrays: StaticArrays, @SVector
11+
using JLArrays: JLArrays
1112
using Test
1213

1314
using ExplicitImports
@@ -75,6 +76,9 @@ end
7576
@testset "Weird" begin
7677
test_differentiation(AutoForwardDiff(), component_scenarios(); logging=LOGGING)
7778
test_differentiation(AutoForwardDiff(), static_scenarios(); logging=LOGGING)
79+
test_differentiation(
80+
DI.AutoForwardFromPrimitive(AutoForwardDiff()), gpu_scenarios(); logging=LOGGING
81+
)
7882

7983
@testset "Batch size" begin
8084
@test DI.pick_batchsize(AutoForwardDiff(), rand(7)) isa DI.BatchSizeSettings{7}

DifferentiationInterface/test/Core/SimpleFiniteDiff/test.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -141,6 +141,12 @@ end
141141

142142
@testset "Weird arrays" begin
143143
test_differentiation(
144-
AutoSimpleFiniteDiff(), vcat(static_scenarios(), gpu_scenarios()); logging=LOGGING
144+
[
145+
AutoSimpleFiniteDiff(),
146+
AutoForwardFromPrimitive(AutoSimpleFiniteDiff()),
147+
AutoReverseFromPrimitive(AutoSimpleFiniteDiff()),
148+
],
149+
vcat(static_scenarios(), gpu_scenarios());
150+
logging=LOGGING,
145151
)
146152
end;

0 commit comments

Comments
 (0)