Skip to content

Commit b6b3569

Browse files
authored
pirate errors for two mistakes (#1976)
1 parent f2ecdf6 commit b6b3569

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

src/Flux.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,15 @@ using MacroTools: @forward
77

88
@reexport using NNlib
99
using MLUtils
10-
import Optimisers: trainable, destructure # before v0.13, Flux owned these functions
10+
import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
1111

1212
using Zygote, ChainRulesCore
1313
using Zygote: Params, @adjoint, gradient, pullback, @nograd
1414
export gradient
1515

16+
# Pirate error to catch a common mistake. (Internal function `base` because overloading `update!` is more likely to give ambiguities.)
17+
Optimisers.base(dx::Zygote.Grads) = error("Optimisers.jl cannot be used with Zygote.jl's implicit gradients, `Params` & `Grads`")
18+
1619
export Chain, Dense, Maxout, SkipConnection, Parallel,
1720
RNN, LSTM, GRU, GRUv3,
1821
SamePad, Conv, CrossCor, ConvTranspose, DepthwiseConv,
@@ -38,6 +41,9 @@ include("utils.jl")
3841
include("onehot.jl")
3942
include("functor.jl")
4043

44+
# Pirate error to catch a common mistake.
45+
Functors.functor(::Type{<:MLUtils.DataLoader}, x) = error("`DataLoader` does not support Functors.jl, thus functions like `Flux.gpu` will not act on its contents.")
46+
4147
include("layers/stateless.jl")
4248
include("layers/basic.jl")
4349
include("layers/conv.jl")

0 commit comments

Comments
 (0)