|
1 | 1 | module LuxCoreEnzymeCoreExt |
2 | 2 |
|
3 | | -using EnzymeCore: EnzymeRules |
| 3 | +using EnzymeCore: EnzymeCore, EnzymeRules |
4 | 4 | using LuxCore: LuxCore |
5 | 5 | using Random: AbstractRNG |
6 | 6 |
|
7 | 7 | EnzymeRules.inactive(::typeof(LuxCore.replicate), ::AbstractRNG) = nothing |
8 | 8 |
|
| 9 | +# Handle common mistakes users might make |
| 10 | +const LAYER_DERIVATIVE_ERROR_MSG = """ |
| 11 | +Lux Layers only support `EnzymeCore.Const` annotation. |
| 12 | +
|
| 13 | +Lux Layers are immutable constants and gradients w.r.t. them are `nothing`. To |
| 14 | +compute the gradients w.r.t. the layer's parameters, use the first argument returned |
| 15 | +by `LuxCore.setup(rng, layer)` instead. |
| 16 | +""" |
| 17 | + |
| 18 | +function EnzymeCore.Active(::LuxCore.AbstractExplicitLayer) |
| 19 | + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) |
| 20 | +end |
| 21 | + |
| 22 | +for annotation in (:Duplicated, :DuplicatedNoNeed) |
| 23 | + @eval function EnzymeCore.$(annotation)( |
| 24 | + ::LuxCore.AbstractExplicitLayer, ::LuxCore.AbstractExplicitLayer) |
| 25 | + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) |
| 26 | + end |
| 27 | +end |
| 28 | + |
| 29 | +for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed) |
| 30 | + @eval function EnzymeCore.$(annotation)( |
| 31 | + ::LuxCore.AbstractExplicitLayer, ::NTuple{N, <:LuxCore.AbstractExplicitLayer}, |
| 32 | + check::Bool=true) where {N} |
| 33 | + throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG)) |
| 34 | + end |
| 35 | +end |
| 36 | + |
9 | 37 | end |
0 commit comments