@@ -69,37 +69,36 @@ res = Dict(
69
69
)
70
70
71
71
function test_scatter (device, types, ops; pt, ops_skip_types)
72
- for T in types
72
+ for T in types, IT in (Int8, Int64)
73
73
PT = promote_type (T, pt)
74
- @testset " $T " begin
75
- for op in ops
76
- skip_types = get (ops_skip_types, op, [])
77
- @testset " $op " begin
78
- for idx = values (idxs), dims = [0 , 1 ]
79
- idx = device (idx)
80
- dst = device (dsts[dims])
81
-
82
- mutated = true
83
- target_y = res[(op, dims, mutated)]
84
- src = device (srcs[(dims, mutated)])
85
- if op == /
86
- src = src .* T (2 )
87
- end
88
-
89
- @test cpu (scatter! (op, T .(dst), T .(src), idx)) == T .(target_y)
90
- @test cpu (scatter! (op, T .(dst), src, idx)) == PT .(target_y)
91
- if op == /
92
- @test cpu (scatter! (op, T .(dst), T .(src), idx)) == PT .(target_y)
93
- else
94
- @test cpu (scatter! (op, copy (dst), T .(src), idx)) == PT .(target_y)
95
- end
96
-
97
- if T ∉ skip_types
98
- mutated = false
99
- src = device (srcs[(dims, mutated)])
100
- @test cpu (scatter (op, T .(src), idx)) == T .(res[(op, dims, mutated)])
101
- end
102
- end
74
+ @testset " eltype $T - idx eltype $IT - $op " for op in ops
75
+ skip_types = get (ops_skip_types, op, [])
76
+ for idx = values (idxs), dims = [0 , 1 ]
77
+ # Tests with indices of different types.
78
+ eltype (idx) == Int && (idx = IT .(idx);)
79
+
80
+ idx = device (idx)
81
+ dst = device (dsts[dims])
82
+
83
+ mutated = true
84
+ target_y = res[(op, dims, mutated)]
85
+ src = device (srcs[(dims, mutated)])
86
+ if op == /
87
+ src = src .* T (2 )
88
+ end
89
+
90
+ @test cpu (scatter! (op, T .(dst), T .(src), idx)) == T .(target_y)
91
+ @test cpu (scatter! (op, T .(dst), src, idx)) == PT .(target_y)
92
+ if op == /
93
+ @test cpu (scatter! (op, T .(dst), T .(src), idx)) == PT .(target_y)
94
+ else
95
+ @test cpu (scatter! (op, copy (dst), T .(src), idx)) == PT .(target_y)
96
+ end
97
+
98
+ if T ∉ skip_types
99
+ mutated = false
100
+ src = device (srcs[(dims, mutated)])
101
+ @test cpu (scatter (op, T .(src), idx)) == T .(res[(op, dims, mutated)])
103
102
end
104
103
end
105
104
end
@@ -174,14 +173,14 @@ function scatter_testsuite(Backend)
174
173
else
175
174
(+ , - , mean, max, min)
176
175
end
177
- for op in ops, i in (0 , 1 )
176
+ for op in ops, i in (0 , 1 ), IT in (Int8, Int64)
178
177
PT = ( # If not CPU and CUDA -> use Int64 for min/max.
179
178
Backend != CPU &&
180
179
Symbol (Backend) != :CUDABackend &&
181
180
(op == max || op == min)) ? Int64 : T
182
181
183
182
src = device (srcs[(i, true )])
184
- idx = device (idxs[:int ])
183
+ idx = device (IT .( idxs[:int ]) )
185
184
dst = device (PT .(dsts[i]))
186
185
Backend == CPU ?
187
186
gradtest_fn (x -> scatter! (op, copy (x), src, idx), dst; fdm= fdm (op)) :
@@ -195,19 +194,20 @@ function scatter_testsuite(Backend)
195
194
else
196
195
(+ , - , mean, max, min)
197
196
end
198
- for op in ops, i in (0 , 1 )
197
+ for op in ops, i in (0 , 1 ), IT in (Int8, Int64)
199
198
PT = ( # If not CPU and CUDA -> use Int64 for min/max.
200
199
Backend != CPU &&
201
200
Symbol (Backend) != :CUDABackend &&
202
201
(op == max || op == min)) ? Int64 : T
203
202
src = PT .(device (srcs[(i, false )]))
204
- idx = device (idxs[:int ])
203
+ idx = device (IT .( idxs[:int ]) )
205
204
Backend == CPU ?
206
205
gradtest_fn (xs -> scatter (op, xs, idx), src; fdm= fdm (op)) :
207
206
gradtest_fn ((xs, i) -> scatter (op, xs, i), src, idx)
208
207
end
209
208
end
210
209
210
+
211
211
@static if Test_Enzyme
212
212
213
213
@testset " EnzymeRules" begin
0 commit comments