Skip to content

Commit 0ccb6c9

Browse files
committed
Add dropout
1 parent c32de74 commit 0ccb6c9

File tree

1 file changed

+71
-0
lines changed

1 file changed

+71
-0
lines changed

src/enzyme.jl

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,4 +292,75 @@ end
292292
end
293293
end
294294

295+
function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT}
295296

297+
T = float(real(eltype(dst.val)))
298+
val = convert(T, 1/(1-p.val))
299+
keep = if dims.val isa Colon
300+
similar(dst.val, T, size(dst.val))
301+
else
302+
similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val)))
303+
end
304+
rand!(rng.val, keep)
305+
306+
keep = keep .> p.val
307+
308+
if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated
309+
dst.val .= (keep .* val) .* src.val
310+
end
311+
312+
primal = if EnzymeCore.EnzymeRules.needs_primal(config)
313+
dst.val
314+
else
315+
nothing
316+
end
317+
shadow = if EnzymeCore.EnzymeRules.needs_shadow(config)
318+
dst.dval
319+
else
320+
nothing
321+
end
322+
323+
if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const
324+
keep = nothing
325+
end
326+
327+
# Cache idx if its overwritten
328+
cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4]
329+
&& !(typeof(src) <: EnzymeCore.Const)
330+
&& !(typeof(dst) <: EnzymeCore.Const)
331+
) ? copy(idx.val) : nothing
332+
333+
return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep)
334+
end
335+
336+
function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT}
337+
T = float(real(eltype(dst.val)))
338+
val = convert(T, 1/(1-p.val))
339+
340+
ddsts = dst.dval
341+
dsrcs = src.dval
342+
343+
if EnzymeCore.EnzymeRules.width(config) == 1
344+
ddsts = (ddsts,)
345+
dsrcs = (dsrcs,)
346+
end
347+
348+
for (ddst, dsrc) in zip(ddsts, dsrcs)
349+
if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val
350+
351+
if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val
352+
dsrc .+= (keep .* val) .* ddst
353+
end
354+
355+
ddst .= 0
356+
end
357+
end
358+
359+
dp = if typeof(p) <: EnzymeCore.Active
360+
typeof(p.val)(0)
361+
else
362+
nothing
363+
end
364+
365+
return (nothing, nothing, nothing, dp, nothing)
366+
end

0 commit comments

Comments
 (0)