Skip to content

Commit 607de4b

Browse files
authored
Allow non-Int64 indices in scatter (#543)
* Fix typo in version comparison * Allow non-Int64 indices in scatter * Disable Enzyme for AMDGPU * Refactor
1 parent af0aa2c commit 607de4b

File tree

3 files changed

+46
-38
lines changed

3 files changed

+46
-38
lines changed

src/scatter.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ end
108108

109109
@kernel function _scatter!(op::OP, dst, src, idxs) where OP
110110
i = @index(Global)
111-
@inbounds idx = Tuple(idxs[i])
111+
@inbounds idx = Tuple(_convert_i64(idxs[i]))
112112
@inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])
113113
# FIXME `@atomic` macro silently fails to perform atomic op below
114114
# @atomic dst[idx...] = op(dst[idx...], src[i])
@@ -119,14 +119,20 @@ end
119119
) where OP
120120
i = @index(Global)
121121
j, k = divrem(i - 1, max_dims_idx)
122-
@inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(idxs[j + 1])...)
122+
@inbounds idx = (Tuple(dim_ids[k + 1])..., Tuple(_convert_i64(idxs[j + 1]))...)
123123
@inbounds Atomix.modify!(Atomix.IndexableRef(dst, idx), op, src[i])
124-
# FIXME
124+
# FIXME `@atomic` macro silently fails to perform atomic op below
125125
# dim_i = Tuple(dim_ids[k + 1])
126126
# idx = idxs[j + 1]
127127
# @atomic dst[dim_i..., idx...] = op(dst[dim_i..., idx...], src[i])
128128
end
129129

130+
# Allow non-Int64 indices by converting them to Int64 when index eltype <: Integer.
131+
# All other index types (tuples, cartesian indices) must be in Int64 already.
132+
# An index that is already an `Int` needs no conversion and is returned as-is.
@inline function _convert_i64(index::Int)
    return index
end
133+
# Any integer index (e.g. `Int8`, `UInt16`, `Int32`) is widened/converted to `Int`.
@inline function _convert_i64(index::Integer)
    return Int(index)
end
134+
# Values without a more specific method (such as tuples or Cartesian indices)
# are passed through unchanged.
@inline function _convert_i64(index)
    return index
end
135+
130136
"""
131137
NNlib.scatter(op, src, idx; [init, dstsize])
132138

test/runtests.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ using Adapt
1313
using KernelAbstractions
1414
import ReverseDiff as RD # used in `pooling.jl`
1515

16-
const Test_Enzyme = VERSION <= v"1.10" && !Sys.iswindows()
16+
const Test_Enzyme = VERSION <= v"1.10-" && !Sys.iswindows() &&
17+
# TODO Enzyme is not working properly with AMDGPU yet.
18+
get(ENV, "NNLIB_TEST_AMDGPU", "false") != "true"
1719

1820
DocMeta.setdocmeta!(NNlib, :DocTestSetup, :(using NNlib, UnicodePlots); recursive=true)
1921

test/scatter.jl

Lines changed: 34 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -69,37 +69,36 @@ res = Dict(
6969
)
7070

7171
function test_scatter(device, types, ops; pt, ops_skip_types)
72-
for T in types
72+
for T in types, IT in (Int8, Int64)
7373
PT = promote_type(T, pt)
74-
@testset "$T" begin
75-
for op in ops
76-
skip_types = get(ops_skip_types, op, [])
77-
@testset "$op" begin
78-
for idx = values(idxs), dims = [0, 1]
79-
idx = device(idx)
80-
dst = device(dsts[dims])
81-
82-
mutated = true
83-
target_y = res[(op, dims, mutated)]
84-
src = device(srcs[(dims, mutated)])
85-
if op == /
86-
src = src .* T(2)
87-
end
88-
89-
@test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y)
90-
@test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y)
91-
if op == /
92-
@test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y)
93-
else
94-
@test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y)
95-
end
96-
97-
if T ∉ skip_types
98-
mutated = false
99-
src = device(srcs[(dims, mutated)])
100-
@test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)])
101-
end
102-
end
74+
@testset "eltype $T - idx eltype $IT - $op" for op in ops
75+
skip_types = get(ops_skip_types, op, [])
76+
for idx = values(idxs), dims = [0, 1]
77+
# Tests with indices of different types.
78+
eltype(idx) == Int && (idx = IT.(idx);)
79+
80+
idx = device(idx)
81+
dst = device(dsts[dims])
82+
83+
mutated = true
84+
target_y = res[(op, dims, mutated)]
85+
src = device(srcs[(dims, mutated)])
86+
if op == /
87+
src = src .* T(2)
88+
end
89+
90+
@test cpu(scatter!(op, T.(dst), T.(src), idx)) == T.(target_y)
91+
@test cpu(scatter!(op, T.(dst), src, idx)) == PT.(target_y)
92+
if op == /
93+
@test cpu(scatter!(op, T.(dst), T.(src), idx)) == PT.(target_y)
94+
else
95+
@test cpu(scatter!(op, copy(dst), T.(src), idx)) == PT.(target_y)
96+
end
97+
98+
if T ∉ skip_types
99+
mutated = false
100+
src = device(srcs[(dims, mutated)])
101+
@test cpu(scatter(op, T.(src), idx)) == T.(res[(op, dims, mutated)])
103102
end
104103
end
105104
end
@@ -174,14 +173,14 @@ function scatter_testsuite(Backend)
174173
else
175174
(+, -, mean, max, min)
176175
end
177-
for op in ops, i in (0, 1)
176+
for op in ops, i in (0, 1), IT in (Int8, Int64)
178177
PT = ( # If not CPU and CUDA -> use Int64 for min/max.
179178
Backend != CPU &&
180179
Symbol(Backend) != :CUDABackend &&
181180
(op == max || op == min)) ? Int64 : T
182181

183182
src = device(srcs[(i, true)])
184-
idx = device(idxs[:int])
183+
idx = device(IT.(idxs[:int]))
185184
dst = device(PT.(dsts[i]))
186185
Backend == CPU ?
187186
gradtest_fn(x -> scatter!(op, copy(x), src, idx), dst; fdm=fdm(op)) :
@@ -195,19 +194,20 @@ function scatter_testsuite(Backend)
195194
else
196195
(+, -, mean, max, min)
197196
end
198-
for op in ops, i in (0, 1)
197+
for op in ops, i in (0, 1), IT in (Int8, Int64)
199198
PT = ( # If not CPU and CUDA -> use Int64 for min/max.
200199
Backend != CPU &&
201200
Symbol(Backend) != :CUDABackend &&
202201
(op == max || op == min)) ? Int64 : T
203202
src = PT.(device(srcs[(i, false)]))
204-
idx = device(idxs[:int])
203+
idx = device(IT.(idxs[:int]))
205204
Backend == CPU ?
206205
gradtest_fn(xs -> scatter(op, xs, idx), src; fdm=fdm(op)) :
207206
gradtest_fn((xs, i) -> scatter(op, xs, i), src, idx)
208207
end
209208
end
210209

210+
211211
@static if Test_Enzyme
212212

213213
@testset "EnzymeRules" begin

0 commit comments

Comments
 (0)