|
| 1 | +function cartesianindex_reverse(I::CartesianIndex) |
| 2 | + return CartesianIndex(reverse(Tuple(I))) |
| 3 | +end |
| 4 | +tuple_oneto(n) = ntuple(identity, n) |
| 5 | + |
| 6 | +# TODO: Use `Base.PermutedDimsArrays.genperm` or |
| 7 | +# https://github.com/jipolanco/StaticPermutations.jl? |
| 8 | +genperm(v, perm) = map(j -> v[j], perm) |
| 9 | + |
| 10 | +## TODO: Use this and something similar for `Dictionary` to make a faster |
| 11 | +## implementation of `storedvalues(::SubArray)`. |
| 12 | +## function valuesview(d::Dict, keys) |
| 13 | +## return @view d.vals[[Base.ht_keyindex(d, key) for key in keys]] |
| 14 | +## end |
| 15 | + |
| 16 | +function eachstoredparentindex(a::SubArray) |
| 17 | + return filter(eachstoredindex(parent(a))) do I |
| 18 | + return all(d -> I[d] ∈ parentindices(a)[d], 1:ndims(parent(a))) |
| 19 | + end |
| 20 | +end |
| 21 | +function storedvalues(a::SubArray) |
| 22 | + return @view parent(a)[collect(eachstoredparentindex(a))] |
| 23 | +end |
| 24 | +function isstored(a::SubArray, I::Int...) |
| 25 | + return isstored(parent(a), Base.reindex(parentindices(a), I)...) |
| 26 | +end |
| 27 | +function getstoredindex(a::SubArray, I::Int...) |
| 28 | + return getstoredindex(parent(a), Base.reindex(parentindices(a), I)...) |
| 29 | +end |
| 30 | +function getunstoredindex(a::SubArray, I::Int...) |
| 31 | + return getunstoredindex(parent(a), Base.reindex(parentindices(a), I)...) |
| 32 | +end |
| 33 | +function eachstoredindex(a::SubArray) |
| 34 | + nonscalardims = filter(tuple_oneto(ndims(parent(a)))) do d |
| 35 | + return !(parentindices(a)[d] isa Real) |
| 36 | + end |
| 37 | + return collect(( |
| 38 | + CartesianIndex( |
| 39 | + map(nonscalardims) do d |
| 40 | + return findfirst(==(I[d]), parentindices(a)[d]) |
| 41 | + end, |
| 42 | + ) for I in eachstoredparentindex(a) |
| 43 | + )) |
| 44 | +end |
| 45 | + |
| 46 | +perm(::PermutedDimsArray{<:Any,<:Any,p}) where {p} = p |
| 47 | +iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,ip}) where {ip} = ip |
| 48 | + |
1 | 49 | storedvalues(a::PermutedDimsArray) = storedvalues(parent(a)) |
2 | 50 | function isstored(a::PermutedDimsArray, I::Int...) |
3 | | - return isstored(parent(a), reverse(I)...) |
4 | | -end |
5 | | -function eachstoredindex(a::PermutedDimsArray) |
6 | | - # TODO: Make lazy with `Iterators.map`. |
7 | | - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) |
| 51 | + return isstored(parent(a), genperm(I, iperm(a))...) |
8 | 52 | end |
9 | 53 | function getstoredindex(a::PermutedDimsArray, I::Int...) |
10 | | - return getstoredindex(parent(a), reverse(I)...) |
| 54 | + return getstoredindex(parent(a), genperm(I, iperm(a))...) |
11 | 55 | end |
12 | 56 | function getunstoredindex(a::PermutedDimsArray, I::Int...) |
13 | | - return getunstoredindex(parent(a), reverse(I)...) |
| 57 | + return getunstoredindex(parent(a), genperm(I, iperm(a))...) |
14 | 58 | end |
15 | 59 | function setstoredindex!(a::PermutedDimsArray, value, I::Int...) |
16 | | - setstoredindex!(parent(a), value, reverse(I)...) |
| 60 | + # TODO: Should this be `iperm(a)`? |
| 61 | + setstoredindex!(parent(a), value, genperm(I, perm(a))...) |
17 | 62 | return a |
18 | 63 | end |
19 | 64 | function setunstoredindex!(a::PermutedDimsArray, value, I::Int...) |
20 | | - setunstoredindex!(parent(a), value, reverse(I)...) |
| 65 | + # TODO: Should this be `iperm(a)`? |
| 66 | + setunstoredindex!(parent(a), value, genperm(I, perm(a))...) |
21 | 67 | return a |
22 | 68 | end |
23 | | - |
24 | | -using LinearAlgebra: Adjoint |
25 | | -storedvalues(a::Adjoint) = storedvalues(parent(a)) |
26 | | -function isstored(a::Adjoint, i::Int, j::Int) |
27 | | - return isstored(parent(a), j, i) |
28 | | -end |
29 | | -function eachstoredindex(a::Adjoint) |
| 69 | +function eachstoredindex(a::PermutedDimsArray) |
30 | 70 | # TODO: Make lazy with `Iterators.map`. |
31 | | - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) |
32 | | -end |
33 | | -function getstoredindex(a::Adjoint, i::Int, j::Int) |
34 | | - return getstoredindex(parent(a), j, i)' |
35 | | -end |
36 | | -function getunstoredindex(a::Adjoint, i::Int, j::Int) |
37 | | - return getunstoredindex(parent(a), j, i)' |
38 | | -end |
39 | | -function setstoredindex!(a::Adjoint, value, i::Int, j::Int) |
40 | | - setstoredindex!(parent(a), value', j, i) |
41 | | - return a |
42 | | -end |
43 | | -function setunstoredindex!(a::Adjoint, value, i::Int, j::Int) |
44 | | - setunstoredindex!(parent(a), value', j, i) |
45 | | - return a |
| 71 | + return map(collect(eachstoredindex(parent(a)))) do I |
| 72 | + return CartesianIndex(genperm(I, perm(a))) |
| 73 | + end |
46 | 74 | end |
47 | 75 |
|
48 | | -using LinearAlgebra: Transpose |
49 | | -storedvalues(a::Transpose) = storedvalues(parent(a)) |
50 | | -function isstored(a::Transpose, i::Int, j::Int) |
51 | | - return isstored(parent(a), j, i) |
52 | | -end |
53 | | -function eachstoredindex(a::Transpose) |
54 | | - # TODO: Make lazy with `Iterators.map`. |
55 | | - return map(CartesianIndex ∘ reverse ∘ Tuple, collect(eachstoredindex(parent(a)))) |
56 | | -end |
57 | | -function getstoredindex(a::Transpose, i::Int, j::Int) |
58 | | - return transpose(getstoredindex(parent(a), j, i)) |
59 | | -end |
60 | | -function getunstoredindex(a::Transpose, i::Int, j::Int) |
61 | | - return transpose(getunstoredindex(parent(a), j, i)) |
62 | | -end |
63 | | -function setstoredindex!(a::Transpose, value, i::Int, j::Int) |
64 | | - setstoredindex!(parent(a), transpose(value), j, i) |
65 | | - return a |
66 | | -end |
67 | | -function setunstoredindex!(a::Transpose, value, i::Int, j::Int) |
68 | | - setunstoredindex!(parent(a), transpose(value), j, i) |
69 | | - return a |
| 76 | +for (type, func) in ((:Adjoint, :adjoint), (:Transpose, :transpose)) |
| 77 | + @eval begin |
| 78 | + using LinearAlgebra: $type |
| 79 | + storedvalues(a::$type) = storedvalues(parent(a)) |
| 80 | + function isstored(a::$type, i::Int, j::Int) |
| 81 | + return isstored(parent(a), j, i) |
| 82 | + end |
| 83 | + function eachstoredindex(a::$type) |
| 84 | + # TODO: Make lazy with `Iterators.map`. |
| 85 | + return map(cartesianindex_reverse, collect(eachstoredindex(parent(a)))) |
| 86 | + end |
| 87 | + function getstoredindex(a::$type, i::Int, j::Int) |
| 88 | + return $func(getstoredindex(parent(a), j, i)) |
| 89 | + end |
| 90 | + function getunstoredindex(a::$type, i::Int, j::Int) |
| 91 | + return $func(getunstoredindex(parent(a), j, i)) |
| 92 | + end |
| 93 | + function setstoredindex!(a::$type, value, i::Int, j::Int) |
| 94 | + setstoredindex!(parent(a), $func(value), j, i) |
| 95 | + return a |
| 96 | + end |
| 97 | + function setunstoredindex!(a::$type, value, i::Int, j::Int) |
| 98 | + setunstoredindex!(parent(a), $func(value), j, i) |
| 99 | + return a |
| 100 | + end |
| 101 | + end |
70 | 102 | end |
0 commit comments