diff --git a/docs/src/reference.md b/docs/src/reference.md
index 5edde719..8ed28b8a 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -164,3 +164,10 @@ NNlib.glu
 NNlib.within_gradient
 bias_act!
 ```
+
+Finally, the `allowslow` switch turns the warnings printed on various slow fallback paths into errors.
+It's a bit like `CUDA.allowscalar(false)`.
+
+```@docs
+allowslow
+```
diff --git a/src/NNlib.jl b/src/NNlib.jl
index 687206fc..cb82b751 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -18,6 +18,15 @@ using Statistics: mean
 
 const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
 
+"""
+    allowslow(::Bool)
+
+By default, NNlib will print warnings the first time various slow fallback paths are taken.
+Calling `allowslow(false)` will instead make these into errors.
+"""
+allowslow(flag::Bool) = (SLOWERROR[] = !flag; nothing)
+const SLOWERROR = Ref(false)
+
 # Include APIs
 include("dim_helpers.jl")
 export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims
diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl
index ccd9b0e8..c71ed8c3 100644
--- a/src/batched/batchedmul.jl
+++ b/src/batched/batchedmul.jl
@@ -274,7 +274,10 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
         size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
         size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
 
-        @debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C)
+        @warn "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) maxlog=1
+        if SLOWERROR[]
+            error("calling fallback method for batched_mul!")
+        end
 
         Abase, Bbase = _unbatch(A), _unbatch(B)
         sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
diff --git a/src/conv.jl b/src/conv.jl
index fead2ee2..4d67f4a4 100644
--- a/src/conv.jl
+++ b/src/conv.jl
@@ -191,6 +191,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
         x_cs = Iterators.partition(1:size(in1, 4),
@@ -232,6 +233,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
@@ -275,6 +277,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
         dw_cs = Iterators.partition(1:size(out, 5),
@@ -326,6 +329,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
         $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
     end
diff --git a/src/fold.jl b/src/fold.jl
index f3c205e1..6594aeb0 100644
--- a/src/fold.jl
+++ b/src/fold.jl
@@ -16,35 +16,35 @@ and a potential inverse of `unfold`.
 The below example demonstrates that `unfold` uses the same sliding windows as `conv`.
 In general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.
 ```jldoctest
-julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1
+julia> x = reshape(Float32[100 2 3 40 5 6 700], 7, 1, 1);  # 1D data, 1 channel, batch of 1
 
-julia> w = reshape([1 0 -1], 3, 1, 1);  # 1D conv kernel of length 3
+julia> w = reshape(Float32[1 0 -1], 3, 1, 1);  # 1D conv kernel of length 3
 
 julia> kws = (pad=1, stride=2, flipped=true);  # use same args for conv and unfold
 
 julia> z = NNlib.unfold(x, size(w); kws...)
-4×3×1 Array{Int64, 3}:
+4×3×1 Array{Float32, 3}:
 [:, :, 1] =
-  0  100    2
-  2    3   40
- 40    5    6
-  6  700    0
+   0.0  100.0    2.0
+   2.0    3.0   40.0
+  40.0    5.0    6.0
+   6.0  700.0    0.0
 
 julia> y1 = conv(x, w; kws...)
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
 [:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0
 
 julia> y2 = z ⊠ w  # ⊠ (\\boxtimes) is NNlib.batched_mul
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
 [:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0
 ```
 """
 function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
diff --git a/test/batchedmul.jl b/test/batchedmul.jl
index 1b8b08e1..b249112c 100644
--- a/test/batchedmul.jl
+++ b/test/batchedmul.jl
@@ -303,3 +303,14 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect
 
     gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
 end
+
+@testset "warning / error" begin
+    prev = NNlib.SLOWERROR[]
+    NNlib.allowslow(true)
+    A = rand(1:99, 3,4,7)
+    B = rand(1:99, 4,5,7)
+    @test batched_mul(A, B) isa Array  # no error!
+    NNlib.allowslow(false)
+    @test_throws Exception batched_mul(A, B)
+    NNlib.SLOWERROR[] = prev
+end