2 changes: 1 addition & 1 deletion Project.toml
@@ -48,4 +48,4 @@ ScopedValues = "1.3.0"
SpecialFunctions = "2"
Statistics = "1"
cuDNN = "1"
julia = "1.9"
julia = "1.10"
2 changes: 1 addition & 1 deletion src/dim_helpers/ConvDims.jl
@@ -73,7 +73,7 @@ function im2col_dims(c::ConvDims)
# Size of single dotproduct within convolution
prod(kernel_size(c))*channels_in(c),
# One workspace per thread
VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
Threads.nthreads(:default),
)
end

2 changes: 1 addition & 1 deletion src/gemm.jl
@@ -95,7 +95,7 @@ for (gemm, elt) in gemm_datatype_mappings
strC = Base.stride(C, 3)

n_threads = min(
VERSION > v"1.9.0-0" ? Threads.nthreads(:default) : Threads.nthreads(),
Threads.nthreads(:default),
1 + max(length(A), length(B)) ÷ 8000)
# In some tests, size (20,20,20) is worth splitting between two threads,
# as is size (32,32,8).
18 changes: 0 additions & 18 deletions src/utils.jl
@@ -144,21 +144,3 @@ function rrule(cfg::RuleConfig{>:HasReverseMode}, ::typeof(_fast_broadcast!), f:
rrule_via_ad(cfg, broadcast, f, x, ys...)
end

# Could get this from Compat.jl instead
# https://github.com/JuliaLang/julia/pull/39794
if VERSION < v"1.7.0-DEV.793"
struct Returns{V} <: Function
value::V
Returns{V}(value) where {V} = new{V}(value)
Returns(value) = new{Core.Typeof(value)}(value)
end

(obj::Returns)(args...; kw...) = obj.value
function Base.show(io::IO, obj::Returns)
show(io, typeof(obj))
print(io, "(")
show(io, obj.value)
print(io, ")")
end
end

11 changes: 11 additions & 0 deletions test.jl
@@ -0,0 +1,11 @@
import Metal, NNlib, Flux

dev = Flux.get_device()

src, idx = Int32[1 2 3 4; 5 6 7 8], Int32[2,1,1,5]
srcd, idxd = dev(src), dev(idx)
y = NNlib.scatter(+, src, idx)
yd = dev(zero(y))
NNlib.scatter!(+, yd, srcd, idxd)


2 changes: 1 addition & 1 deletion test/conv.jl
@@ -908,7 +908,7 @@ end
gradtest((y, w) -> sum(∇depthwiseconv_data(y, w, dcdims)), y, w)
end

@static if Test_Enzyme
if NNLIB_TEST_ENZYME

@testset "EnzymeRules: conv! spatial_rank=$spatial_rank" for spatial_rank in (1, 2, 3)
x = rand(rng, repeat([5], spatial_rank)..., 3, 2)
5 changes: 1 addition & 4 deletions test/dropout.jl
@@ -16,9 +16,6 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
@test size(@inferred dropout(rng, x1, 0.1; dims=2)) == (3, 4)

x2 = Diagonal(randn(Float32, 10)) # Just to check it runs on weird matrices.
if VERSION > v"1.8-" # on 1.6 this makes a sparse array.
@test dropout(x2, 0.3) isa Matrix{Float32} # does not infer, but that's OK?
end

# Values
@test dropout(x1, 0) == x1
@@ -76,7 +73,7 @@ using Zygote, StableRNGs, ChainRulesCore, Enzyme
@test_throws ArgumentError dropout!(y1, x1, 3)
end

@static if Test_Enzyme
if NNLIB_TEST_ENZYME

@testset "EnzymeRules: dropout " begin
rng = Random.default_rng()
2 changes: 1 addition & 1 deletion test/pooling.jl
@@ -948,7 +948,7 @@ end
gradtest(x -> sum(meanpool(x, k)), x)
end

@static if Test_Enzyme
if NNLIB_TEST_ENZYME

@testset "EnzymeRules: pooling! $pool spatial_rank=$spatial_rank " for spatial_rank in (1, 2),
(pool, pool!) in ((maxpool, maxpool!), (meanpool, meanpool!))
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -18,10 +18,10 @@ import ReverseDiff as RD # used in `pooling.jl`
import Pkg
using SpecialFunctions

const Test_Enzyme = VERSION <= v"1.10-"

DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)

const NNLIB_TEST_ENZYME = true
# ENV["NNLIB_TEST_CUDA"] = "true" # uncomment to run CUDA tests
# ENV["NNLIB_TEST_AMDGPU"] = "true" # uncomment to run AMDGPU tests
# ENV["NNLIB_TEST_CPU"] = "false" # uncomment to skip CPU tests
2 changes: 1 addition & 1 deletion test/testsuite/gather.jl
@@ -154,7 +154,7 @@ function gather_testsuite(Backend)
gradtest_fn((s, i) -> gather(s, i), src, idx)
end

@static if Test_Enzyme
if NNLIB_TEST_ENZYME

@testset "EnzymeRules: gather! gradient for scalar index" begin
src = device(Float64[3, 4, 5, 6, 7])
2 changes: 1 addition & 1 deletion test/testsuite/scatter.jl
@@ -208,7 +208,7 @@ function scatter_testsuite(Backend)
end


@static if Test_Enzyme
if NNLIB_TEST_ENZYME

@testset "EnzymeRules" begin
idx = device([2, 2, 3, 4, 4])