Skip to content

Commit 8c119cf

Browse files
Fix global linear indexing (fill!) (#496)
1 parent 6ecb909 commit 8c119cf

File tree

2 files changed

+7
-11
lines changed

2 files changed

+7
-11
lines changed

src/MetalKernels.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,9 @@ end
140140
end
141141

142142
@device_override @inline function KA.__index_Global_Linear(ctx)
143-
return thread_position_in_grid_1d()
143+
I = @inbounds KA.expand(KA.__iterspace(ctx), threadgroup_position_in_grid_1d(), thread_position_in_threadgroup_1d())
144+
# TODO: This is unfortunate, can we get the linear index cheaper
145+
@inbounds LinearIndices(KA.__ndrange(ctx))[I]
144146
end
145147

146148
@device_override @inline function KA.__index_Local_Cartesian(ctx)

test/array.jl

Lines changed: 4 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -229,15 +229,12 @@ end
229229

230230
@testset "fill($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
231231
Float16, Float32]
232-
broken466a = T [Int8,UInt8]
233-
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)
234-
235232
b = rand(T)
236233

237234
# Dims in tuple
238235
let A = Metal.fill(b, (10, 10, 10, 1000))
239236
B = fill(b, (10, 10, 10, 1000))
240-
@test Array(A) == B broken=(broken466a && broken466b)
237+
@test Array(A) == B
241238
end
242239

243240
let M = Metal.fill(b, (10, 10))
@@ -253,7 +250,7 @@ end
253250
#Dims already unpacked
254251
let A = Metal.fill(b, 10, 1000, 1000)
255252
B = fill(b, 10, 1000, 1000)
256-
@test Array(A) == B broken=broken466a
253+
@test Array(A) == B
257254
end
258255

259256
let M = Metal.fill(b, 10, 10)
@@ -269,15 +266,12 @@ end
269266

270267
@testset "fill!($T)" for T in [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64,
271268
Float16, Float32]
272-
broken466a = T [Int8,UInt8]
273-
broken466b = (Base.JLOptions().check_bounds != 1 || shader_validation)
274-
275269
b = rand(T)
276270

277271
# Dims in tuple
278272
let A = MtlArray{T,3}(undef, (10, 1000, 1000))
279273
fill!(A, b)
280-
@test all(Array(A) .== b) broken=broken466a
274+
@test all(Array(A) .== b)
281275
end
282276

283277
let M = MtlMatrix{T}(undef, (10, 10))
@@ -293,7 +287,7 @@ end
293287
# Dims already unpacked
294288
let A = MtlArray{T,4}(undef, 10, 10, 10, 1000)
295289
fill!(A, b)
296-
@test all(Array(A) .== b) broken=(broken466a && broken466b)
290+
@test all(Array(A) .== b)
297291
end
298292

299293
let M = MtlMatrix{T}(undef, 10, 10)

0 commit comments

Comments
 (0)