Skip to content

Commit a7b1f11

Browse files
hstrey (Helmut Strey), avik-pal
authored
feat: support for ForwardDiff training (#1273)
* added extension for ForwardDiff * moved compute_gradients_imp ForwardDiff dispatch to /helpers/training/jl * removed LuxForwardDiffExt from Project.toml * Update src/helpers/training.jl * Update src/helpers/training.jl * added test for ForwardDiff training * removed () * created new testitem for ForwardDiff and added ForwardDiff Limitation to docstring * added test condition at the end of ForwardDiff test, and reduced function calls * feat: use caching to reduce memory allocations * Apply suggestions from code review --------- Co-authored-by: Helmut Strey <Helmut.Strey@stonybrook.edu> Co-authored-by: Avik Pal <avik.pal.2017@gmail.com> Co-authored-by: Avik Pal <avikpal@mit.edu>
1 parent d5dfb0c commit a7b1f11

File tree

5 files changed

+150
-2
lines changed

5 files changed

+150
-2
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.10.1"
4+
version = "1.11.0"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -11,6 +11,7 @@ ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
1111
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1212
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
1313
ConcreteStructs = "2569d6c7-a4a2-43d3-a901-331e8e4be471"
14+
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
1415
DispatchDoctor = "8d63f2c5-f18a-4cf2-ba9d-b3f60fc568c8"
1516
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1617
FastClosures = "9aa1b823-49e4-5ca5-8b0f-3971ec8bab6a"
@@ -82,6 +83,7 @@ ChainRulesCore = "1.25"
8283
Compat = "4.16"
8384
ComponentArrays = "0.15.22"
8485
ConcreteStructs = "0.2.3"
86+
DiffResults = "1.1"
8587
DispatchDoctor = "0.4.12"
8688
Enzyme = "0.13.35"
8789
EnzymeCore = "0.8.8"

src/Lux.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ include("extended_ops.jl")
8282
# Training Helpers
8383
include("helpers/optimizers.jl")
8484
include("helpers/training.jl")
85+
include("helpers/forwarddiff_training.jl")
8586

8687
# Experimental
8788
include("contrib/contrib.jl")
@@ -155,7 +156,8 @@ export Training
155156

156157
export jacobian_vector_product, vector_jacobian_product
157158
export batched_jacobian
158-
export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
159+
# NOTE(review): the rewritten export list contained `AutoForwardDiff` twice; a name
# only needs to be exported once, so the duplicate is dropped.
export AutoEnzyme, AutoForwardDiff, AutoReverseDiff, AutoTracker, AutoZygote
159161

160162
export BinaryCrossEntropyLoss,
161163
BinaryFocalLoss,
Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
using ADTypes: AutoForwardDiff
2+
using DiffResults: DiffResults
3+
using ForwardDiff: ForwardDiff
4+
using Setfield: @set!
5+
using Static: True, False
6+
7+
"""
    Training.compute_gradients_impl(ad::AutoForwardDiff, obj_fn, data, ts)

First-call dispatch for ForwardDiff-based gradient computation. Wraps `obj_fn` so the
updated states/stats can be recovered after differentiation, computes value and gradient
in one pass via `DiffResults`, and stores the wrapper plus the result buffer in the
training-state cache so subsequent calls with the same objective allocate nothing.

Returns `(gradient, loss, stats, updated_trainstate)`.
"""
function Training.compute_gradients_impl(
    ad::AutoForwardDiff, obj_fn::F, data, ts::Training.TrainState
) where {F}
    # `@assert` is for internal invariants and may be disabled at higher optimization
    # levels; user-input validation needs an explicit throw.
    ts.parameters isa AbstractArray || throw(ArgumentError(
        "AutoForwardDiff only supports AbstractArray parameters, not " *
        "$(typeof(ts.parameters)). To convert the parameter structure to an array " *
        "use `ComponentArray(ps)`.",
    ))

    # Wrap the objective so states/stats updated inside it are captured in the refs.
    obj_fn_wrap, st_wrap, stats_wrap = Training.wrap_objective_function(
        obj_fn, ts.model, ts.parameters, ts.states, data, True()
    )

    # DiffResults stores value and gradient together, avoiding a second evaluation.
    gradient_result = DiffResults.GradientResult(ts.parameters)
    ForwardDiff.gradient!(
        gradient_result, ps -> obj_fn_wrap(ts.model, ps, ts.states, data), ts.parameters
    )

    # Cache the wrapped objective and the result buffer for the fast path taken when
    # the same objective function is used on the next call.
    cache = Training.TrainingBackendCache(
        ad, False(), gradient_result, (; obj_fn=obj_fn_wrap, st_wrap, stats_wrap)
    )
    @set! ts.cache = cache
    @set! ts.objective_function = obj_fn
    @set! ts.states = st_wrap[]
    return (
        DiffResults.gradient(gradient_result),
        DiffResults.value(gradient_result),
        stats_wrap[],
        ts,
    )
end
37+
38+
# Cache layout produced by the first AutoForwardDiff `compute_gradients_impl` call.
# Matching against it selects the allocation-free fast path below.
const FORWARDDIFF_CACHE_TYPE = Training.TrainingBackendCache{
    <:AutoForwardDiff,False,PS,<:NamedTuple{(:obj_fn, :st_wrap, :stats_wrap)}
} where {PS}

# Fast path: the objective function is unchanged since the first call, so the cached
# wrapped objective and the preallocated DiffResults buffer can be reused directly.
function Training.compute_gradients_impl(
    ::AutoForwardDiff, obj_fn::F, data, ts::Training.TrainState{<:FORWARDDIFF_CACHE_TYPE,F}
) where {F}
    res = ts.cache.dparameters
    wrapped_obj = ts.cache.extras.obj_fn

    ForwardDiff.gradient!(
        res, ps -> wrapped_obj(ts.model, ps, ts.states, data), ts.parameters
    )

    @set! ts.objective_function = obj_fn
    @set! ts.states = ts.cache.extras.st_wrap[]

    return (
        DiffResults.gradient(res),
        DiffResults.value(res),
        ts.cache.extras.stats_wrap[],
        ts,
    )
end
63+
64+
# Fallback path: an AutoForwardDiff cache exists but the objective function differs
# from the cached one. The gradient buffer is still reused, but the objective wrapper
# is rebuilt (and deliberately not cached) on every call.
function Training.compute_gradients_impl(
    ::AutoForwardDiff,
    obj_fn::F,
    data,
    ts::Training.TrainState{<:Training.TrainingBackendCache{<:AutoForwardDiff,False}},
) where {F}
    @warn "Detected calls to `compute_gradients(::AutoForwardDiff, ...)` with objective \
           function that is changing across function calls. This can lead to the \
           generation of slow code" maxlog = 1

    res = ts.cache.dparameters

    # Same procedure as the initial dispatch, except the wrapper is not stored.
    wrapped_obj, st_ref, stats_ref = Training.wrap_objective_function(
        obj_fn, ts.model, ts.parameters, ts.states, data, False()
    )

    ForwardDiff.gradient!(
        res, ps -> wrapped_obj(ts.model, ps, ts.states, data), ts.parameters
    )

    @set! ts.states = st_ref[]
    return (DiffResults.gradient(res), DiffResults.value(res), stats_ref[], ts)
end

src/helpers/training.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ Compute the gradients of the objective function wrt parameters stored in `ts`.
160160
| `AutoReverseDiff(; compile)` | `ReverseDiff.jl` |
161161
| `AutoTracker` | `Tracker.jl` |
162162
| `AutoEnzyme` | `Enzyme.jl` |
163+
| `AutoForwardDiff` | `ForwardDiff.jl` |
163164
164165
## Arguments
165166
@@ -185,6 +186,8 @@ A 4-Tuple containing:
185186
- `AutoReverseDiff(; compile=true)` is not supported for Lux models with non-empty state
186187
`st`. Additionally the returned stats must be empty (`NamedTuple()`). We catch these
187188
issues in most cases and throw an error.
189+
- `AutoForwardDiff` only works with parameters that are `AbstractArray`s
  (e.g. `ps = ComponentVector(ps)`)
188191
189192
!!! danger "Aliased Gradients"
190193

test/helpers/training_tests.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,56 @@ end
139139
end
140140
end
141141

142+
@testitem "Training API ForwardDiff" setup = [SharedTestSetup] tags = [:misc] begin
    using ADTypes, Optimisers, ComponentArrays

    loss_fn = MSELoss()

    rng = StableRNG(12345)

    # Small synthetic regression problem; targets rescaled to [0, 1].
    x_data = randn(rng, Float32, 4, 32)
    y_data = evalpoly.(x_data, ((1, 2, 3),)) .- evalpoly.(x_data, ((5, 2),))
    y_data = (y_data .- minimum(y_data)) ./ (maximum(y_data) - minimum(y_data))
    batches = [(x_data[:, idx], y_data[:, idx]) for idx in Iterators.partition(1:32, 8)]

    model = Chain(
        Dense(4, 32, tanh), BatchNorm(32), Dense(32, 32, tanh), BatchNorm(32), Dense(32, 4)
    )

    optimizer = Adam(0.001f0)

    # AutoForwardDiff requires flat array parameters, hence the ComponentVector.
    ps, st = Lux.setup(rng, model)
    train_state = Training.TrainState(model, ComponentVector(ps), st, optimizer)

    initial_loss = first(
        loss_fn(model, train_state.parameters, Lux.testmode(train_state.states), batches[1])
    )

    # Exercise compute_gradients + apply_gradients! (hits the cached fast path).
    for _ in 1:100, (x, y) in batches
        grads, _, _, train_state = allow_unstable() do
            Training.compute_gradients(AutoForwardDiff(), loss_fn, (x, y), train_state)
        end
        train_state = Training.apply_gradients!(train_state, grads)
    end

    # Exercise the in-place single-step API.
    for _ in 1:100, (x, y) in batches
        grads, _, _, train_state = allow_unstable() do
            Training.single_train_step!(AutoForwardDiff(), loss_fn, (x, y), train_state)
        end
    end

    # Exercise the out-of-place single-step API.
    for _ in 1:100, (x, y) in batches
        grads, _, _, train_state = allow_unstable() do
            Training.single_train_step(AutoForwardDiff(), loss_fn, (x, y), train_state)
        end
    end

    final_loss = first(
        loss_fn(model, train_state.parameters, train_state.states, batches[1])
    )

    # Training should have reduced the loss substantially.
    @test final_loss * 50 < initial_loss
end
191+
142192
@testitem "Enzyme: Invalidate Cache on State Update" setup = [SharedTestSetup] tags = [
143193
:misc
144194
] skip = :(using LuxTestUtils; !LuxTestUtils.ENZYME_TESTING_ENABLED) begin

0 commit comments

Comments
 (0)