Skip to content

Commit 80274cd

Browse files
authored
add PermutedDimsArray method that returns PermutedDiskArray (#249)
* add PermutedDimsArray method that returns PermutedDiskArray * bugfix recursion * A type param * bugfix parent and ConstructionBase
1 parent 294fc9c commit 80274cd

File tree

6 files changed

+78
-49
lines changed

6 files changed

+78
-49
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@ authors = ["Fabian Gans <[email protected]>"]
44
version = "0.4.12"
55

66
[deps]
7+
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
78
LRUCache = "8ac3fa9e-de4c-5943-b1dc-09c6b5f20637"
89
Mmap = "a63ad114-7e13-5084-954f-fe012c677804"
910
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
1011

1112
[compat]
1213
Aqua = "0.8"
14+
ConstructionBase = "1"
1315
LRUCache = "1"
1416
Mmap = "1"
1517
OffsetArrays = "1"

src/DiskArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
module DiskArrays
22

3+
import ConstructionBase
4+
35
using LRUCache: LRUCache, LRU
46

57
# Use the README as the module docs

src/permute.jl

Lines changed: 36 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,41 @@
33
44
A lazily permuted disk array returned by `permutedims(diskarray, permutation)`.
55
"""
6-
struct PermutedDiskArray{T,N,P<:PermutedDimsArray{T,N}} <: AbstractDiskArray{T,N}
7-
a::P
6+
struct PermutedDiskArray{T,N,perm,iperm,A<:AbstractArray{T,N}} <: AbstractDiskArray{T,N}
7+
parent::A
88
end
9+
# We use PermutedDimsArray internals instead of duplicating them,
10+
# and just copy the type parameters it calculates.
11+
PermutedDiskArray(A::AbstractArray, perm::Union{Tuple,AbstractVector}) =
12+
PermutedDiskArray(A, PermutedDimsArray(CartesianIndices(A), perm))
13+
function PermutedDiskArray(
14+
a::A, ::PermutedDimsArray{<:Any,<:Any,perm,iperm}
15+
) where {A<:AbstractArray{T,N},perm,iperm} where {T,N}
16+
PermutedDiskArray{T,N,perm,iperm,A}(a)
17+
end
18+
19+
# We need explicit ConstructionBase support as perm and iperm are only in the type.
20+
# We include N so that only arrays of the same dimensionality can be set with this perm and iperm
21+
struct PermutedDiskArrayConstructor{N,perm,iperm} end
22+
23+
(::PermutedDiskArrayConstructor{N,perm,iperm})(a::A) where A<:AbstractArray{T,N} where {T,N,perm,iperm} =
24+
PermutedDiskArray{T,N,perm,iperm,A}(a)
25+
26+
ConstructionBase.constructorof(::Type{<:PermutedDiskArray{<:Any,N,perm,iperm}}) where {N,perm,iperm} =
27+
PermutedDiskArrayConstructor{N,perm,iperm}()
928

1029
# Base methods
1130

12-
Base.size(a::PermutedDiskArray) = size(a.a)
31+
Base.parent(a::PermutedDiskArray) = a.parent
32+
Base.size(a::PermutedDiskArray) = genperm(size(parent(a)), _getperm(a))
1333

1434
# DiskArrays interface
1535

16-
haschunks(a::PermutedDiskArray) = haschunks(a.a.parent)
36+
haschunks(a::PermutedDiskArray) = haschunks(parent(a))
1737
function eachchunk(a::PermutedDiskArray)
1838
# Get the parent chunks
19-
gridchunks = eachchunk(a.a.parent)
20-
perm = _getperm(a.a)
39+
gridchunks = eachchunk(parent(a))
40+
perm = _getperm(a)
2141
# Return permuted GridChunks
2242
return GridChunks(genperm(gridchunks.chunks, perm)...)
2343
end
@@ -26,33 +46,29 @@ function DiskArrays.readblock!(a::PermutedDiskArray, aout, i::OrdinalRange...)
2646
# Permute the indices
2747
inew = genperm(i, iperm)
2848
# Permute the dest block and read from the true parent
29-
DiskArrays.readblock!(a.a.parent, PermutedDimsArray(aout, iperm), inew...)
49+
DiskArrays.readblock!(parent(a), PermutedDimsArray(aout, iperm), inew...)
3050
return nothing
3151
end
3252
function DiskArrays.writeblock!(a::PermutedDiskArray, v, i::OrdinalRange...)
3353
iperm = _getiperm(a)
3454
inew = genperm(i, iperm)
3555
# Permute the dest block and write from the true parent
36-
DiskArrays.writeblock!(a.a.parent, PermutedDimsArray(v, iperm), inew...)
56+
DiskArrays.writeblock!(parent(a), PermutedDimsArray(v, iperm), inew...)
3757
return nothing
3858
end
3959

40-
_getperm(a::PermutedDiskArray) = _getperm(a.a)
41-
_getperm(::PermutedDimsArray{<:Any,<:Any,perm}) where {perm} = perm
60+
_getperm(::PermutedDiskArray{<:Any,<:Any,perm}) where {perm} = perm
61+
_getiperm(::PermutedDiskArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm
4262

43-
_getiperm(a::PermutedDiskArray) = _getiperm(a.a)
44-
_getiperm(::PermutedDimsArray{<:Any,<:Any,<:Any,iperm}) where {iperm} = iperm
45-
46-
# Implementaion macros
47-
48-
function permutedims_disk(a, perm)
49-
pd = PermutedDimsArray(a, perm)
50-
return PermutedDiskArray{eltype(a),ndims(a),typeof(pd)}(pd)
51-
end
63+
# Implementation macro
5264

5365
macro implement_permutedims(t)
5466
t = esc(t)
5567
quote
56-
Base.permutedims(parent::$t, perm) = permutedims_disk(parent, perm)
68+
Base.permutedims(parent::$t, perm) = PermutedDiskArray(parent, perm)
69+
# It's not correct to return a PermutedDiskArray from the PermutedDimsArray constructor.
70+
# Instead we need a Base julia method that behaves like view for SubArray, such as `lazypermutedims`.
71+
# But until that exists this is better than returning a broken disk array.
72+
Base.PermutedDimsArray(parent::$t, perm) = PermutedDiskArray(parent, perm)
5773
end
5874
end

src/reshape.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ end
2323

2424
# Base methods
2525

26+
Base.parent(r::ReshapedDiskArray) = r.parent
2627
Base.size(r::ReshapedDiskArray) = r.newsize
2728

2829
# DiskArrays interface

src/util/testtypes.jl

Lines changed: 27 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@ DiskArrays.batchstrategy(a::AccessCountDiskArray) = a.batchstrategy
2525
AccessCountDiskArray(a; chunksize=size(a), batchstrategy=DiskArrays.ChunkRead(DiskArrays.NoStepRange(), 0.5)) =
2626
AccessCountDiskArray([], [], a, chunksize, batchstrategy)
2727

28-
Base.size(a::AccessCountDiskArray) = size(a.parent)
28+
Base.parent(a::AccessCountDiskArray) = a.parent
29+
Base.size(a::AccessCountDiskArray) = size(parent(a))
2930

3031
# Apply the all in one macro rather than inheriting
3132

@@ -38,7 +39,7 @@ function DiskArrays.readblock!(a::AccessCountDiskArray, aout, i::OrdinalRange...
3839
end
3940
# println("reading from indices ", join(string.(i)," "))
4041
push!(a.getindex_log, i)
41-
return aout .= a.parent[i...]
42+
return aout .= parent(a)[i...]
4243
end
4344
function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...)
4445
ndims(a) == length(i) || error("Number of indices is not correct")
@@ -47,31 +48,30 @@ function DiskArrays.writeblock!(a::AccessCountDiskArray, v, i::OrdinalRange...)
4748
end
4849
# println("Writing to indices ", join(string.(i)," "))
4950
push!(a.setindex_log, i)
50-
return view(a.parent, i...) .= v
51+
return view(parent(a), i...) .= v
5152
end
5253

5354
getindex_count(a::AccessCountDiskArray) = length(a.getindex_log)
5455
setindex_count(a::AccessCountDiskArray) = length(a.setindex_log)
5556
getindex_log(a::AccessCountDiskArray) = a.getindex_log
5657
setindex_log(a::AccessCountDiskArray) = a.setindex_log
57-
trueparent(a::AccessCountDiskArray) = a.parent
58-
59-
getindex_count(a::DiskArrays.ReshapedDiskArray) = getindex_count(a.parent)
60-
setindex_count(a::DiskArrays.ReshapedDiskArray) = setindex_count(a.parent)
61-
getindex_log(a::DiskArrays.ReshapedDiskArray) = getindex_log(a.parent)
62-
setindex_log(a::DiskArrays.ReshapedDiskArray) = setindex_log(a.parent)
63-
trueparent(a::DiskArrays.ReshapedDiskArray) = trueparent(a.parent)
64-
65-
getindex_count(a::DiskArrays.PermutedDiskArray) = getindex_count(a.a.parent)
66-
setindex_count(a::DiskArrays.PermutedDiskArray) = setindex_count(a.a.parent)
67-
getindex_log(a::DiskArrays.PermutedDiskArray) = getindex_log(a.a.parent)
68-
setindex_log(a::DiskArrays.PermutedDiskArray) = setindex_log(a.a.parent)
69-
function trueparent(
70-
a::DiskArrays.PermutedDiskArray{T,N,<:PermutedDimsArray{T,N,perm,iperm}}
71-
) where {T,N,perm,iperm}
72-
return permutedims(trueparent(a.a.parent), perm)
58+
trueparent(a::AccessCountDiskArray) = parent(a)
59+
60+
getindex_count(a::DiskArrays.AbstractDiskArray) = getindex_count(parent(a))
61+
setindex_count(a::DiskArrays.AbstractDiskArray) = setindex_count(parent(a))
62+
getindex_log(a::DiskArrays.AbstractDiskArray) = getindex_log(parent(a))
63+
setindex_log(a::DiskArrays.AbstractDiskArray) = setindex_log(parent(a))
64+
function trueparent(a::DiskArrays.AbstractDiskArray)
65+
if parent(a) === a
66+
a
67+
else
68+
trueparent(parent(a))
69+
end
7370
end
7471

72+
trueparent(a::DiskArrays.PermutedDiskArray{T,N,perm,iperm}) where {T,N,perm,iperm} =
73+
permutedims(trueparent(parent(a)), perm)
74+
7575
"""
7676
ChunkedDiskArray(A; chunksize)
7777
@@ -83,29 +83,31 @@ struct ChunkedDiskArray{T,N,A<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArr
8383
end
8484
ChunkedDiskArray(a; chunksize=size(a)) = ChunkedDiskArray(a, chunksize)
8585

86-
Base.size(a::ChunkedDiskArray) = size(a.parent)
86+
Base.parent(a::ChunkedDiskArray) = a.parent
87+
Base.size(a::ChunkedDiskArray) = size(parent(a))
8788

8889
DiskArrays.haschunks(::ChunkedDiskArray) = DiskArrays.Chunked()
8990
DiskArrays.eachchunk(a::ChunkedDiskArray) = DiskArrays.GridChunks(a, a.chunksize)
90-
DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= a.parent[i...]
91-
DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(a.parent, i...) .= v
91+
DiskArrays.readblock!(a::ChunkedDiskArray, aout, i::AbstractUnitRange...) = aout .= parent(a)[i...]
92+
DiskArrays.writeblock!(a::ChunkedDiskArray, v, i::AbstractUnitRange...) = view(parent(a), i...) .= v
9293

9394
"""
9495
UnchunkedDiskArray(A)
9596
9697
A disk array without chunking, that can wrap any other `AbstractArray`.
9798
"""
9899
struct UnchunkedDiskArray{T,N,P<:AbstractArray{T,N}} <: DiskArrays.AbstractDiskArray{T,N}
99-
p::P
100+
parent::P
100101
end
101102

102-
Base.size(a::UnchunkedDiskArray) = size(a.p)
103+
Base.parent(a::UnchunkedDiskArray) = a.parent
104+
Base.size(a::UnchunkedDiskArray) = size(parent(a))
103105

104106
DiskArrays.haschunks(::UnchunkedDiskArray) = DiskArrays.Unchunked()
105107
function DiskArrays.readblock!(a::UnchunkedDiskArray, aout, i::AbstractUnitRange...)
106108
ndims(a) == length(i) || error("Number of indices is not correct")
107109
all(r -> isa(r, AbstractUnitRange), i) || error("Not all indices are unit ranges")
108-
return aout .= a.p[i...]
110+
return aout .= parent(a)[i...]
109111
end
110112

111113
end

test/runtests.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ using DiskArrays.TestTypes
44
using Test
55
using Statistics
66
using Aqua
7+
using ConstructionBase
78
using TraceFuns, Suppressor
89

910
# Run with any code changes
@@ -818,15 +819,20 @@ import Base.PermutedDimsArrays.invperm
818819
ip = invperm(p)
819820
a = permutedims(AccessCountDiskArray(permutedims(reshape(1:20, 4, 5, 1), ip)), p)
820821
test_getindex(a)
821-
a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p)
822+
a = PermutedDimsArray(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p)
822823
test_setindex(a)
823824
a = permutedims(AccessCountDiskArray(zeros(Int, 5, 1, 4)), p)
824825
test_view(a)
825-
a = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p)
826-
test_reductions(a)
826+
f = data -> permutedims(AccessCountDiskArray(permutedims(data, ip); chunksize=(4, 2, 5)), p)
827+
test_reductions(f)
827828
a_disk1 = permutedims(AccessCountDiskArray(rand(9, 2, 10); chunksize=(3, 2, 5)), p)
828829
test_broadcast(a_disk1)
829-
@test PermutedDiskArray(a_disk1.a) === a_disk1
830+
831+
@testset "ConstructionBase works on PermutedDiskArray" begin
832+
v = ones(Int, 10, 2, 2)
833+
av = ConstructionBase.setproperties(a, (; parent=v))
834+
@test parent(av) === v
835+
end
830836
end
831837

832838
@testset "Unchunked String arrays" begin

0 commit comments

Comments
 (0)