Skip to content

Commit ca051c8

Browse files
committed
Use Adapt to support setindex with CPU slices.
1 parent 040bacd commit ca051c8

File tree

2 files changed

+8
-11
lines changed

2 files changed

+8
-11
lines changed

src/indexing.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,22 +113,13 @@ end
113113
end
114114
end
115115

116-
# FIXME: this should use adapt
117-
gpu_convert(GPUType, x::GPUArray) = x
118-
function gpu_convert(GPUType, x::AbstractArray)
119-
isbits(x) ? x : convert(GPUType, x)
120-
end
121-
function gpu_convert(GPUType, x)
122-
isbits(x) ? x : error("Only isbits types are allowed for indexing. Found: $(typeof(x))")
123-
end
124-
125116
function Base._unsafe_setindex!(::IndexStyle, dest::T, src, Is::Union{Real, AbstractArray}...) where T <: GPUArray
126117
if length(Is) == 1 && isa(first(Is), Array) && isempty(first(Is)) # indexing with empty array
127118
return dest
128119
end
129120
idims = length.(Is)
130121
len = prod(idims)
131-
src_gpu = gpu_convert(T, src)
122+
src_gpu = adapt(T, src)
132123
gpu_call(setindex_kernel!, dest, (dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len), len)
133124
return dest
134125
end

test/testsuite/indexing.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ function test_indexing(AT)
1818
x[2:6, 2:6, :, :] = y
1919
x[2:6, 2:6, :, :] == y
2020
end
21-
21+
@testset "multi dim, sliced setindex, CPU source" begin
22+
x = fill(AT{T}, T(0), (2,3,4))
23+
y = Array{T}(undef, 2,3)
24+
rand!(y)
25+
x[:, :, 2] = y
26+
x[:, :, 2] == y
27+
end
2228
end
2329

2430
for T in (Float32, Int32)

0 commit comments

Comments
 (0)