|
292 | 292 | end
|
293 | 293 | end
|
294 | 294 |
|
| 295 | +function EnzymeCore.EnzymeRules.augmented_primal(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, rng, dst::OutType, src, p, dims) where {OutType, RT} |
295 | 296 |
|
| 297 | + T = float(real(eltype(dst.val))) |
| 298 | + val = convert(T, 1/(1-p.val)) |
| 299 | + keep = if dims.val isa Colon |
| 300 | + similar(dst.val, T, size(dst.val)) |
| 301 | + else |
| 302 | + similar(dst.val, T, ntuple(d -> d in dims.val ? size(dst.val,d) : 1, ndims(dst.val))) |
| 303 | + end |
| 304 | + rand!(rng.val, keep) |
| 305 | + |
| 306 | + keep = keep .> p.val |
| 307 | + |
| 308 | + if OutType <: EnzymeCore.Duplicated || OutType <: EnzymeCore.BatchDuplicated |
| 309 | + dst.val .= (keep .* val) .* src.val |
| 310 | + end |
| 311 | + |
| 312 | + primal = if EnzymeCore.EnzymeRules.needs_primal(config) |
| 313 | + dst.val |
| 314 | + else |
| 315 | + nothing |
| 316 | + end |
| 317 | + shadow = if EnzymeCore.EnzymeRules.needs_shadow(config) |
| 318 | + dst.dval |
| 319 | + else |
| 320 | + nothing |
| 321 | + end |
| 322 | + |
| 323 | + if typeof(dst) <: EnzymeCore.Const || typeof(src) <: EnzymeCore.Const |
| 324 | + keep = nothing |
| 325 | + end |
| 326 | + |
| 327 | + # Cache idx if its overwritten |
| 328 | + cache_idx = ( EnzymeCore.EnzymeRules.overwritten(config)[4] |
| 329 | + && !(typeof(src) <: EnzymeCore.Const) |
| 330 | + && !(typeof(dst) <: EnzymeCore.Const) |
| 331 | + ) ? copy(idx.val) : nothing |
| 332 | + |
| 333 | + return EnzymeCore.EnzymeRules.AugmentedReturn(primal, shadow, keep) |
| 334 | +end |
| 335 | + |
| 336 | +function EnzymeCore.EnzymeRules.reverse(config, func::EnzymeCore.Const{typeof(NNlib._dropout!)}, ::Type{RT}, keep, rng, dst::OutType, src, p, dims) where {OutType, RT} |
| 337 | + T = float(real(eltype(dst.val))) |
| 338 | + val = convert(T, 1/(1-p.val)) |
| 339 | + |
| 340 | + ddsts = dst.dval |
| 341 | + dsrcs = src.dval |
| 342 | + |
| 343 | + if EnzymeCore.EnzymeRules.width(config) == 1 |
| 344 | + ddsts = (ddsts,) |
| 345 | + dsrcs = (dsrcs,) |
| 346 | + end |
| 347 | + |
| 348 | + for (ddst, dsrc) in zip(ddsts, dsrcs) |
| 349 | + if !(typeof(dst) <: EnzymeCore.Const) && ddst !== dst.val |
| 350 | + |
| 351 | + if !(typeof(src) <: EnzymeCore.Const) && dsrc !== src.val |
| 352 | + dsrc .+= (keep .* val) .* ddst |
| 353 | + end |
| 354 | + |
| 355 | + ddst .= 0 |
| 356 | + end |
| 357 | + end |
| 358 | + |
| 359 | + dp = if typeof(p) <: EnzymeCore.Active |
| 360 | + typeof(p.val)(0) |
| 361 | + else |
| 362 | + nothing |
| 363 | + end |
| 364 | + |
| 365 | + return (nothing, nothing, nothing, dp, nothing) |
| 366 | +end |
0 commit comments