Skip to content

Commit 2fff2b0

Browse files
committed
fix: Specialized ReshapedArray dispatch to resolve setindex! ambiguities
1 parent efdc7cf commit 2fff2b0

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

src/host/indexing.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -167,8 +167,14 @@ end
167167
function Base._unsafe_setindex!(::IndexStyle, A::WrappedGPUArray, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
168168
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
169169
end
170-
# And allow one more `ReshapedArray` wrapper to handle the `_maybe_reshape` optimization.
171-
function Base._unsafe_setindex!(::IndexStyle, A::Base.ReshapedArray{<:Any, <:Any, <:WrappedGPUArray}, x, Is::Vararg{Union{Real,AbstractArray}, N}) where N
170+
171+
#Implementation for ReshapedArrays using Cartesian indexing to resolve dispatch ties.
172+
function Base._unsafe_setindex!(::Base.IndexCartesian, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
173+
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
174+
end
175+
176+
#Implementation for ReshapedArrays using Linear indexing to resolve dispatch ties.
177+
function Base._unsafe_setindex!(::Base.IndexLinear, A::Base.ReshapedArray{T, N, <:WrappedGPUArray}, x, Is::Vararg{Union{Real, AbstractArray}, N}) where {T, N}
172178
return vectorized_setindex!(A, x, Base.ensure_indexable(Is)...)
173179
end
174180

test/testsuite/indexing.jl

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -284,3 +284,78 @@ end
284284
@test compare(argmin, AT, -rand(Int, 10))
285285
end
286286
end
287+
288+
@testsuite "indexing combinatorial" (AT, eltypes) -> begin
289+
@testset "Reshaped SubArray dispatch" for T in eltypes
290+
#3D slice assignment
291+
@testset "3D slice assignment" begin
292+
let A = AT(ones(T, 4, 4, 4))
293+
@views V = A[:, :, 1:2]
294+
@allowscalar begin
295+
@test_nowarn V[:] .= T(0)
296+
@test all(Array(V) .== T(0))
297+
end
298+
end
299+
end
300+
301+
#Logical indexing views
302+
@testset "Logical mask view" begin
303+
let A = AT(ones(T, 4, 4, 4))
304+
mask = [true, false, true, false]
305+
@views V = A[:, :, mask]
306+
@allowscalar begin
307+
@test_nowarn V .+= T(2)
308+
@test all(Array(V) .== T(3))
309+
end
310+
end
311+
end
312+
313+
#Nested Reshape layers
314+
@testset "Nested Reshape" begin
315+
let A = AT(ones(T, 4, 4, 4))
316+
V = view(A, 1:2, 1:2, 1:2)
317+
R1 = reshape(V, (4, 2))
318+
R2 = reshape(R1, :)
319+
@allowscalar begin
320+
@test_nowarn R2 .+= T(1)
321+
@test all(Array(R2) .== T(2))
322+
end
323+
end
324+
end
325+
end
326+
327+
@testset "Permuted and Reinterpreted Views" for T in eltypes
328+
#PermutedDimsArray + Reshape
329+
@testset "Reshaped PermutedDims" begin
330+
let A = AT(ones(T, 4, 4))
331+
P = PermutedDimsArray(A, (2, 1))
332+
R = reshape(P, :)
333+
@allowscalar begin
334+
@test_nowarn R[1:2] .= T(0)
335+
@test Array(R)[1] == T(0)
336+
end
337+
end
338+
end
339+
340+
#Reinterpreted Reshape
341+
@testset "Reshaped Reinterpreted" begin
342+
let A = AT(ones(T, 4, 4))
343+
IT = T <: Complex ? Complex{Int16} : Int16
344+
R = reshape(reinterpret(IT, A), :)
345+
@allowscalar begin
346+
@test_nowarn R[1] = IT(0)
347+
@test Array(R)[1] == IT(0)
348+
end
349+
end
350+
end
351+
end
352+
353+
@testset "Data parity with compare()" for T in eltypes
354+
@test compare(AT, rand(T, 8, 8, 8)) do A
355+
mask = [true, false, true, false, true, false, true, false]
356+
@views V = A[:, mask, :]
357+
@allowscalar V .+= T(1)
358+
A
359+
end
360+
end
361+
end

0 commit comments

Comments
 (0)