Skip to content

Commit b6c80ba

Browse files
committed
Add dropout
1 parent 1bb3081 commit b6c80ba

File tree

2 files changed

+97
-1
lines changed

2 files changed

+97
-1
lines changed

src/enzyme.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,75 @@ end
292292
end
293293
end
294294

295+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}
295296

297+
T = float(real(eltype(dst.val)))
298+
val = convert(T, 1/(1-p.val))
299+
keep = if dims.val isa Colon
300+
similar(dst.val, T, size(dst.val))
301+
else
302+
similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))
303+
end
304+
rand!(rng.val, keep)
305+
306+
keep = keep .> p.val
307+
308+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
309+
dst.val .= (keep .* val) .* src.val
310+
end
311+
312+
primal = if EnzymeCore.EnzymeRules.needs_primal(config)
313+
dst.val
314+
else
315+
nothing
316+
end
317+
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
318+
dst.dval
319+
else
320+
nothing
321+
end
322+
323+
if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const
324+
keep = nothing
325+
end
326+
327+
# Cache idx if its overwritten
328+
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
329+
&& !(typeof(src) <: EnzymeCore.Const)
330+
&& !(typeof(dst) <: EnzymeCore.Const)
331+
) ? copy(idx.val) : nothing
332+
333+
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
334+
end
335+
336+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}
337+
T = float(real(eltype(dst.val)))
338+
val = convert(T, 1/(1-p.val))
339+
340+
ddsts = dst.dval
341+
dsrcs = src.dval
342+
343+
if EnzymeCore.EnzymeRules.width(config) == 1
344+
ddsts = (ddsts,)
345+
dsrcs = (dsrcs,)
346+
end
347+
348+
for (ddst, dsrc) in zip(ddsts, dsrcs)
349+
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
350+
351+
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
352+
dsrc .+= (keep .* val) .* ddst
353+
end
354+
355+
ddst .= 0
356+
end
357+
end
358+
359+
dp = if typeof(p) <: EnzymeCore.Active
360+
typeof(p.val)(0)
361+
else
362+
nothing
363+
end
364+
365+
return (nothing, nothing, nothing, dp, nothing)
366+
end

test/dropout.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
using NNlib, Test, Statistics, Random, LinearAlgebra
2-
using Zygote, StableRNGs, ChainRulesCore
2+
using Zygote, StableRNGs, ChainRulesCore, Enzyme
33

44
@testset "dropout" begin
55
# Basics
@@ -75,3 +75,28 @@ using Zygote, StableRNGs, ChainRulesCore
7575
@test_throws ArgumentError dropout(x1, 2)
7676
@test_throws ArgumentError dropout!(y1, x1, 3)
7777
end
78+
79+
@testset "EnzymeRules: dropout "
80+
rng = Random.default_rng()
81+
82+
x1 = randn(Float32, 3000, 4000)
83+
dx1 = zeros(Float32, 3000, 4000)
84+
85+
dout = randn(Float32, 3000, 4000)
86+
87+
p = 0.2f0
88+
89+
forward, reverse = Enzyme.autodiff_thunk(ReverseSplitWithPrimal, typeof(Const(dropout)), Duplicated, typeof(Const(rng)), typeof(Duplicated(x1, dx1)), typeof(Const(0.2f0)))
90+
91+
tape, primal, shadow = forward(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p))
92+
93+
shadow .= dout
94+
95+
reverse(Const(dropout), Const(rng), Duplicated(x1, dx1), Const(p), tape)
96+
97+
@test dx1[.!tape[1]] zero(x1)[.!tape[1]]
98+
99+
val = convert(Float32, 1/(1-p))
100+
101+
@test dx1[tape[1]] (val * dout)[tape[1]]
102+
end

0 commit comments

Comments
 (0)