Skip to content

Commit ae26847

Browse files
authored
Merge pull request #223 from JuliaGPU/tb/setindex_cpu_slice
Use Adapt to support setindex with CPU slices.
2 parents 040bacd + 2690f11 commit ae26847

File tree

3 files changed

+10
-13
lines changed

3 files changed

+10
-13
lines changed

.gitlab-ci.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,14 +44,14 @@ cuarrays:
4444
- sm_75
4545
image: juliagpu/cuda:10.1-cudnn7-cutensor-devel-ubuntu18.04
4646
script:
47-
- export CUARRAYS="$HOME/.julia/dev/CuArrays"
47+
- export CUARRAYS=".julia/dev/CuArrays"
4848
- julia -e 'using Pkg;
4949
Pkg.develop("CuArrays");'
5050
- julia --project -e 'using Pkg;
5151
Pkg.instantiate()'
5252
- julia --project=$CUARRAYS -e 'using Pkg;
5353
Pkg.instantiate();
54-
Pkg.add(["FFTW", "ForwardDiff"])'
54+
Pkg.add(["FFTW", "ForwardDiff", "FillArrays"])'
5555
- JULIA_LOAD_PATH=".:$CUARRAYS::" julia $CUARRAYS/test/runtests.jl
5656
allow_failure: true
5757

src/indexing.jl

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -113,22 +113,13 @@ end
113113
end
114114
end
115115

116-
# FIXME: this should use adapt
117-
gpu_convert(GPUType, x::GPUArray) = x
118-
function gpu_convert(GPUType, x::AbstractArray)
119-
isbits(x) ? x : convert(GPUType, x)
120-
end
121-
function gpu_convert(GPUType, x)
122-
isbits(x) ? x : error("Only isbits types are allowed for indexing. Found: $(typeof(x))")
123-
end
124-
125116
function Base._unsafe_setindex!(::IndexStyle, dest::T, src, Is::Union{Real, AbstractArray}...) where T <: GPUArray
126117
if length(Is) == 1 && isa(first(Is), Array) && isempty(first(Is)) # indexing with empty array
127118
return dest
128119
end
129120
idims = length.(Is)
130121
len = prod(idims)
131-
src_gpu = gpu_convert(T, src)
122+
src_gpu = adapt(T, src)
132123
gpu_call(setindex_kernel!, dest, (dest, src_gpu, idims, map(x-> to_index(dest, x), Is), len), len)
133124
return dest
134125
end

test/testsuite/indexing.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ function test_indexing(AT)
1818
x[2:6, 2:6, :, :] = y
1919
x[2:6, 2:6, :, :] == y
2020
end
21-
21+
@testset "multi dim, sliced setindex, CPU source" begin
22+
x = fill(AT{T}, T(0), (2,3,4))
23+
y = Array{T}(undef, 2,3)
24+
rand!(y)
25+
x[:, :, 2] = y
26+
x[:, :, 2] == y
27+
end
2228
end
2329

2430
for T in (Float32, Int32)

0 commit comments

Comments
 (0)