Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit 1c4c2a6

Browse files
committed
feat: error on common mistakes
1 parent 0e59aa5 commit 1c4c2a6

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "LuxCore"
22
uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "0.1.20"
4+
version = "0.1.21"
55

66
[deps]
77
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"

ext/LuxCoreEnzymeCoreExt.jl

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,37 @@
11
module LuxCoreEnzymeCoreExt
22

3-
using EnzymeCore: EnzymeRules
3+
using EnzymeCore: EnzymeCore, EnzymeRules
44
using LuxCore: LuxCore
55
using Random: AbstractRNG
66

77
EnzymeRules.inactive(::typeof(LuxCore.replicate), ::AbstractRNG) = nothing
88

9+
# Handle common mistakes users might make
10+
const LAYER_DERIVATIVE_ERROR_MSG = """
11+
Lux Layers only support `EnzymeCore.Const` annotation.
12+
13+
Lux Layers are immutable constants and gradients w.r.t. them are `nothing`. To
14+
compute the gradients w.r.t. the layer's parameters, use the first argument returned
15+
by `LuxCore.setup(rng, layer)` instead.
16+
"""
17+
18+
function EnzymeCore.Active(::LuxCore.AbstractExplicitLayer)
19+
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
20+
end
21+
22+
for annotation in (:Duplicated, :DuplicatedNoNeed)
23+
@eval function EnzymeCore.$(annotation)(
24+
::LuxCore.AbstractExplicitLayer, ::LuxCore.AbstractExplicitLayer)
25+
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
26+
end
27+
end
28+
29+
for annotation in (:BatchDuplicated, :BatchDuplicatedNoNeed)
30+
@eval function EnzymeCore.$(annotation)(
31+
::LuxCore.AbstractExplicitLayer, ::NTuple{N, <:LuxCore.AbstractExplicitLayer},
32+
check::Bool=true) where {N}
33+
throw(ArgumentError(LAYER_DERIVATIVE_ERROR_MSG))
34+
end
35+
end
36+
937
end

0 commit comments

Comments
 (0)