Skip to content

Commit 46a3383

Browse files
committed
Add PermuteDimsArray support.
1 parent 7c56448 commit 46a3383

File tree

2 files changed

+17
-0
lines changed

2 files changed

+17
-0
lines changed

src/host/base.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,13 @@ function Base.repeat(a::AbstractGPUVector, m::Int)
5353
end
5454
return b
5555
end
56+
57+
## PermutedDimsArrays
58+
59+
using Base: PermutedDimsArrays
60+
61+
# PermutedDimsArrays' custom copyto! doesn't know how to deal with GPU arrays
62+
function PermutedDimsArrays._copy!(dest::PermutedDimsArray{T,N,<:Any,<:Any,<:AbstractGPUArray}, src) where {T,N}
63+
dest .= src
64+
dest
65+
end

test/testsuite/base.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,5 +148,12 @@ function test_base(AT)
148148
@test blocks == 1
149149
@test threads == 1
150150
end
151+
152+
@testset "permutedims" begin
153+
@test compare(x->permutedims(x, [1, 2]), AT, rand(4, 4))
154+
155+
inds = rand(1:100, 150, 150)
156+
@test compare(x->permutedims(view(x, inds, :), (3, 2, 1)), AT, rand(100, 100))
157+
end
151158
end
152159
end

0 commit comments

Comments
 (0)