Skip to content

Commit 3e4a0a1

Browse files
committed
Fix tests
1 parent e80370e commit 3e4a0a1

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

src/abstractblocksparsearray/unblockedsubarray.jl

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ function Base.map!(
6464
a_src_rest...,
6565
)
6666
end
67+
68+
# Fix ambiguity and scalar indexing errors with GPUArrays.
69+
using Adapt: adapt
70+
using GPUArraysCore: GPUArraysCore
71+
function Base.map!(
72+
f,
73+
a_dest::GPUArraysCore.AnyGPUArray,
74+
a_src1::UnblockedSubArray,
75+
a_src_rest::UnblockedSubArray...,
76+
)
77+
a_dest_cpu = adapt(Array, a_dest)
78+
a_srcs_cpu = map(adapt(Array), (a_src1, a_src_rest...))
79+
map!(f, a_dest_cpu, a_srcs_cpu...)
80+
a_dest .= a_dest_cpu
81+
return a_dest
82+
end
83+
6784
function Base.iszero(a::UnblockedSubArray)
68-
return invoke(iszero, Tuple{AbstractArray}, a)
85+
return invoke(iszero, Tuple{AbstractArray}, adapt(Array, a))
6986
end

0 commit comments

Comments
 (0)